Compare commits

...

37 Commits

Author SHA1 Message Date
Aman Gupta 8c146a8366 DeepSeek V4 (#24162)
* convert: add dsv4 conversion

* add basic setup

* add llm_graph_input_dsv4

* add save-load state

* add sinkhorn eps - correction by @fairydreaming

* add rope fix

* cleanup dead code

* fix bugs

* support pro model: added by @fairydreaming

* remove redundant V cache

* Chat template

* remove debugging leftovers

* Add mechanism for inlining templates based on architecture

* s/deepseek-v4-flash/deepseek4/g

* s/deepseek-v4-flash/deepseek4/g continued

* enable graph reuse

* enable FA

* fix test llama archs

* rename

* compatibility with antirez ds4 GGUFs

* simplified set_gguf_parameters() by calling super class method, replaced moe.score_func with expert_gating_func.

* reserve worst-case kv-cache

* revert max split inputs

* address review comments

* add padding to enable FA

* pad only the final value of plan.n_kv to 256

* remove built-in cpp chat template

* cont: remove cpp built-in template

* rm outdated test

* replace ggml_view_3d() with ggml_reshape_3d()

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* only support n_seq=1 for now

* remove unused var

* cont: remove unused var

* use scale bias

* use correct ptr for can_reuse

* remove gen-chat-inline-templates.py

* simplify graph reuse

* cont: cleanup

* remove unused inputs

* enable partial checkpointing

* add correct shape for kq_mask + set llama_model_n_swa to 0 for dsv4

* precompute source_idx + add comment about dummy write

* support multi-seq

* remove restored_trim_pos

* use split_equal when possible

* fix indent

* address review comments

* use LLM_KV

* fix ci

---------

Co-authored-by: Piotr Wilkin <piotr.wilkin@syndatis.com>
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: fairydreaming <166155368+fairydreaming@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-29 16:58:51 +08:00
seryogakovalyov 6cb18b2f2e tools/ui: restore Tailwind scanning in ignored worktrees (#24879) 2026-06-29 10:55:52 +02:00
o7si 277a105dc8 common : remove unused regex-partial (#25118) 2026-06-29 08:48:39 +02:00
Xuan-Son Nguyen b3fed31b99 jinja, chat: add --reasoning-preserve flag (#25105)
* jinja, chat: add --reasoning-preserve flag

* correct help message
2026-06-28 23:33:51 +02:00
Aleksander Grygier dbdaece23d Revert "ui: fix accessibility for hover-gated interactive elements assisted by claude(in debugging and tests) (#24727)" (#25098) 2026-06-28 21:30:03 +02:00
Pascal 7cb8576e7c ui: fix stop and reasoning skip in single-model mode (#25084) 2026-06-28 21:06:43 +02:00
Ruixiang Wang fa72bc6826 dflash: refactor draft model conversion (#25110)
* dflash: refactor draft model conversion

* apply fix for eagle3 convert
2026-06-28 20:31:48 +02:00
Aldehir Rojas c818263f2a chat : implement minicpm5 parser (#24889)
* Add minicpm5 tool call parser

* Refactor MiniCPM5 PEG parser per review feedback

* Fix jinja min/max API to match Jinja2

* modify by review

* MiniCPM5: use autoparser for XML tool calls and fix grammar preserved-token triggers

* MiniCPM5: fix streaming tool-arg placeholder and remove alt XML markers

* skip min/max attribute tests in -py mode

* test-jinja: use real expected output for min/max attribute tests

* MiniCPM5: revert shared mapper and history fallbacks per review

Drop streaming tool-arg placeholder workarounds from the generic PEG
mapper and restore strict tool-call argument JSON parsing so MiniCPM5
support stays limited to autoparser/diff-analyzer changes.

* chat : refactor minicpm5 back to dedicated parser

* cont : simplify grammar

* cont : refactor

* cont : fixes

* cont : rename template to openbmb-MiniCPM5-1B.jinja

* cont : add message delimiters

* cont : fix tests

---------

Co-authored-by: zhangtao <zhangtao2@modelbest.cn>
Co-authored-by: 张涛 <>
2026-06-28 16:53:32 +02:00
Xuan-Son Nguyen f68a788b0b jinja: add --dump-prog for debugging (#25086)
* jinja: add --dump-prog for debugging

* Update common/jinja/runtime.cpp

Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>

---------

Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
2026-06-28 15:50:31 +02:00
Ruixiang Wang d1b34251bc spec : add DFlash support (#22105)
* spec: add DFlash v2 support

* dflash: support sliding window attention per layer_types

* docs: add dflash section

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2026-06-28 16:01:34 +03:00
Adrien Gallouët c1a1c8ee94 common : allow --offline in llama download (#25091)
Expose the existing --offline flag to `llama download` so a script can
run it to check whether a model is already cached and ready to be served
without touching the network.

Also fix a latent use-after-free in the URL-task on_done callback:
first_path is block-scoped and was captured by reference, but invoked
after the block ends.

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-28 12:34:11 +02:00
Georgi Gerganov 27c8bb4f63 logs : reduce v2 (#25078)
* server : reduce logs

* cont : common

* cont : spec

* cont : CMN_ -> COM_
2026-06-28 08:52:15 +03:00
Hongqiang Wang ebd048fc5e opencl: flash attention improvement (#25069)
* opencl: rework FA kernel for f16 and f32

* opencl: flash-attention prefill prepass kernels

- flash_attn_kv_pad_f16    pads the tail KV tile to a BLOCK_N multiple
- flash_attn_mask_pad_f16  pads the matching mask tile
- flash_attn_blk_f16       classifies each KV tile per query block as
                           fully masked / mixed / fully unmasked, so
                           the main kernel can skip fully-masked tiles
                           and the mask lookup for fully-unmasked ones

* opencl: FA kernels for q4_0 and q8_0

* opencl: `set_rows` for f32 to q8_0/q4_0

* opencl: dequant kernels for q4_0 and q8_0

* opencl: add FA tile tuning table with override

* opencl: wire host side for FA

* opencl: q4_0 MoE tensors are also SOA'ed

* opencl: cosmetic fix

* opencl: refactor, also clarify some code paths in comments

* opencl: fix inifity for `-cl-finite-math-only`

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-06-27 15:36:06 -07:00
Gaurav Garg 0ed235ea2c [CUDA] Added a cudaMemcpy2DAsync fast path to ggml_cuda_cpy (#25057)
* [CUDA] Added a cudaMemcpy2DAsync fast path to ggml_cuda_cpy

Add a CUDA ggml_cpy fast path for same-type, same-shape strided copies that are just 2D pitched block copies.
When tensors are not fully contiguous but each row is contiguous, it now uses cudaMemcpy2DAsync instead of the slow element-wise scalar copy kernel.

This fixes the GDN recurrent snapshot update with -np 4, where rollback slots are separated by cache stride gaps.

* Add new tests that execute the new optimized strided copy path

* Return unsupported for strided copy in OpenVINO, as new tests are failing
2026-06-27 17:46:21 +05:30
Neo Zhang 9bebfcb4bc sycl : fix failed ut cases of norm (#25044) 2026-06-27 12:13:43 +03:00
Ruben Ortlam 0b6529d818 vulkan: fix step operator for 0 input (#25036) 2026-06-27 10:57:31 +02:00
Christian Kastner c299a92c38 binaries : Improve rpc-server and export-graph-ops names. (#25045)
Tests are generally prefixed with -test, so rename export-graph-ops
accordingly.

rpc-server is probably too generic a name for /usr/bin. Because it
should work with any ggml application, it is renamed to ggml-rpc-server.
2026-06-27 10:31:29 +03:00
Sigbjørn Skjæret 0275c0f800 ci : add windows-openvino to check-release (#25022) 2026-06-27 10:30:56 +03:00
Sigbjørn Skjæret 83d385b429 tests : fix test-chat-template --no-common option (#25075) 2026-06-27 10:30:19 +03:00
Adrien Gallouët 050ee92d04 app : allow --version, --licenses & --help (#25054)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-26 23:18:11 +02:00
Andreas Kieslinger 3fc4e10527 sched : reintroduce less synchronizations during split compute (#20793)
* CUDA:  Improve performance via less synchronizations between token (#17795)

* Adds CPU-to-CUDA copy capability to
ggml_backend_cuda_cpy_tensor_async()

* Adds function to relax sync requirements between input copies on
supported backends (CUDA for now)

* Exchanges synchronous copy with async copy function.

* Adds macro guards to allow compilation in non-CUDA builds

* Reworked backend detection in ggml-backend.cpp to avoid linking
conflicts

* Relax requirement of checks in async CUDA copies from backend and buffer type to just buffer type, to avoid linking issues

* Minor cleanup

* Makes opt-in to relax use of explicit syncs more general. Backends like
vulkan which require a synchronization between HtoD copies and graph
execution could also adopt this change now.

* Reintroduces stricter check for CPU->CUDA backend async copy via
GGML_DEVICE_TYPE_CPU.

* Corrects initialization of ggml_backend_sync_mode in
ggml_backend_sched_split initialization

* Simplifies synchronizations to adhere to `saaasg` pattern.

* Apply suggestion from @ggerganov (src->buffer to buf_src)

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Apply suggestion from @ggerganov (src->buffer to buf_src) v2

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Apply suggestions from @johannesgaessler code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Adds single-GPU synchronizations to multi-GPU settings to fix hip backend pipeline parallel bugs.

* Scheduler Hardening: Exclude hip/MUSA from copy_from_host CPU split ->
GPU split optimization

* Scheduler Hardening: Re-adding original additional synchronizations for
non-async backends

* Adds disclaimer to hip/musa exclusion of copy_from_host. Highlights that it is out of
precaution, but that no perf-impact is visible, and that it can be
revisited separately anytime.

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-06-26 17:18:30 +03:00
Adrien Gallouët 5d8ccdf9d1 devops : add llama in all docker images (#25035)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-26 15:15:48 +02:00
Xuan-Son Nguyen 024930c6ad arg: fix handling --spec-draft-hf and --hf-repo-v (#25043)
* arg: fix handling --spec-draft-hf and --hf-repo-v

* fix missing mparams.hf_file
2026-06-26 14:36:03 +02:00
Ravi Panchumarthy 5397c36194 openvino: Update to OV 2026.2.1, self-contained release packages, operator improvements (#24974)
* Update to OV 2026.2.1, Make OV release packages self-contained

* Update to OV 2026.2.1, Make OV release packages self-contained

* OpenVINO Backend: Remove compute_op_type hardcoded sets (#222)

* OpenVINO Backend: Remove compute_op_type hardcoded sets

* revert get_op_type removal

* OpenVINO backend: enable softmax with sink input

* OpenVINO backend: opt mul_mat_id convert process for large size

* OpenVINO backend: Modify add_id to support 2D/4D

* OpenVINO Backend: Add glu_swiglu_oai

* PR review: fix paths

* PR review: fix path consistency

---------

Co-authored-by: Mostafa <mostafas.main.email@gmail.com>
Co-authored-by: Xuejun <Xuejun.Zhai@intel.com>
2026-06-26 15:07:19 +03:00
Georgi Gerganov e7ea94afcb sync : ggml 2026-06-26 15:04:42 +03:00
Georgi Gerganov 96183e9820 ggml : bump version to 0.15.3 (ggml/1550) 2026-06-26 15:04:42 +03:00
nullname 487a6cc164 vulkan: opt mul_mat_vecq for mi50 (#22933) 2026-06-26 13:49:24 +02:00
Jiang, Fish 5a6a0dd7e1 vulkan: add INTEL_XE1 arch enum and enable coopmat1 on Intel Xe-LPG Plus (#24404)
* vulkan: add INTEL_PRE_XE2 arch enum and enable coopmat1 on Intel Xe-LPG Plus (1/3, Xe1-ARLH)

Co-authored-by: Xia, Jie <jie.xia@intel.com>
Co-authored-by: Liu, Russell <russell.liu@intel.com>

* Address comments of bf16 and trailing whitespace

* Rename INTEL_PRE_XE2 to INTEL_XE1 and remove driver workaround

* Add Windows driver check

---------

Co-authored-by: Xia, Jie <jie.xia@intel.com>
Co-authored-by: Liu, Russell <russell.liu@intel.com>
2026-06-26 13:26:22 +02:00
Sanjay Ahari ded1561b42 ui: fix accessibility for hover-gated interactive elements assisted by claude(in debugging and tests) (#24727) 2026-06-26 12:55:38 +02:00
Jeff Bolz 9df06805ee vulkan: Workaround compiler bug in conv2d coopmat2 path (#24924)
* vulkan: Workaround compiler bug in conv2d coopmat2 path

* apply same workaround to CONV_3D

* Apply suggestion from @jeffbolznv
2026-06-26 11:53:32 +02:00
leonardHONG 2f18fe13c5 CUDA: add cublasSgemmBatched mapping for HIP/MUSA vendor headers (#25033) 2026-06-26 11:42:56 +02:00
Tarek Dakhran c16c35b814 ggml-cpu: fix SVE leftover path in ggml_vec_dot_f32 (#24699)
* ggml-cpu: fix SVE leftover path in ggml_vec_dot_f32

2D convolutions with kernel size 9 produced different results on SVE
enabled ARM devices. After debugging it turned out that ggml_vec_dot_f32
was using data from inactive lanes.

Use svmla_f32_m(pg, sum1, ax1, ay1) so inactive lanes retain sum1.

* cont : clean-up

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-26 10:41:56 +03:00
Pascal 1a87dcdc45 server + ui: SSE Replay Buffer (#23226)
* server: SSE replay buffer, survives client disconnect

Opt in on POST /v1/chat/completions when the client sends
X-Stream-Resume: 1 and a non empty X-Conversation-Id. The conv id is
the session identity end to end, no extra opaque token. The drain
runs detached server side and buffers SSE bytes, the generation
survives HTTP disconnect, F5, or lets users switch from iOS Safari
to another app without losing the actively generated response.

Routes:
  GET    /v1/stream/<conv_id>?from=N       replay
  GET    /v1/streams[?conversation_id=X]   list, drives sidebar spinners
  DELETE /v1/stream/<conv_id>              Stop, idempotent

Router parent fans out to children for list and delete, probes on GET
to route to the owner, fans out DELETE on POST so "one session per
conv" holds across model swaps.

WebUI: the layout snapshots /v1/streams at mount and on
visibilitychange, the sidebar reflects live inferences across all
convs. The chat page reattaches on mount, append vs fresh is detected
from existing content so continue mid stream keeps its prefix.

update_slots: on llama_memory_seq_rm refusal at a deep position, full
clear of the seq and reprefill from zero instead of GGML_ABORT.

OAI strict path unchanged when the opt in headers are absent.

* server: create stream session only after post_tasks succeeds

* server, ui: drop X-Stream-Resume, X-Conversation-Id alone enables the replay buffer

* server: drop magic 17, derive the X-Conversation-Id header length from sizeof at build time

* refactor: address review feedback from ngxson

* server-context: cleaning

* server-stream: fix use-after-free on rd

Guard stop_producer with a shared alive flag, flipped by on_stream_end
before rd dies. Prevents a late cancel (session eviction by a later
POST on the same conv_id, or a DELETE arriving after the producer
ended) from touching a destroyed rd.

* ui: fix cross-conversation contamination

Scope streaming flags per conv so one finishing does not unflag the
others, guard discoverActiveStream against concurrent runs to avoid
duplicate attaches, and stop racing syncRemoteRunningStreams for the
sidebar set.

* server-http: keep request alive in detached SSE drain

The response next() lambda may reach into *request via &req long
after on_complete reset the request shared_ptr. Capture request in
the detached thread so it outlives the drain.

* ui: address review feedback from coder543

Forward Authorization to /v1/stream and /v1/streams fetches, the resumable routes
must obey --api-key like the rest of the API.

Wrap reader.read() in a try/catch, the underlying connection drop rejects with
TypeError instead of resolving done=true, treat it as a premature end of stream
so the existing resume loop kicks in.

Freeze the model at session start in chatStreamingStates.model and thread it
through cancel and resume, the dropdown selection may have changed since the
POST and the server side identity is fixed at that time.

* format

* ui: remove unused selectedModelName

* server-stream: poll session->is_cancelled() in stream_aware_should_stop

Address review feedback from coder543. The cancel propagation through
rd.stop() relies on the slot eventually processing the cancel task and
posting a result that notifies the recv condvar, remove_waiting_task_ids
does not notify directly. Add a defensive poll on session->is_cancelled()
so the producer-side next() loop exits on its next iteration after
cancel() without waiting for the cancel task to round trip through a slot.

* server-stream, ui: replace GET /v1/streams with POST /v1/streams/lookup

Address review feedback from coder543. Listing live sessions leaks the
conversation_id of every concurrent user, which defeats the random UUID
unguessability. The new route takes {conversation_ids: [...]} in the
body and returns matches only for the ids the caller already owns, so
foreign UUIDs stay private. The router fans out the same POST to every
child and aggregates, the WebUI passes the convs visible in its sidebar.

* ui: read conv ids from IndexedDB in syncRemoteRunningStreams

The conversations store is not hydrated yet at +layout onMount, so the
sidebar spinners stayed off for background convs until the user clicked
on them. Read straight from the DB to dodge the init race.

* server-models: deduplicate stream lookup timeouts behind one constant

* ui: extract visibility kick grace into a stream constant, bump to 1000 ms

* make it safer & more simple

* server-stream: survive client disconnect via stream_pipe::finish_producer

After the RAII rewrite the generation stopped the moment the client
disconnected. httplib bails its content provider on the is_peer_alive
check at the top of write_content_chunked, so returning true from the
provider never keeps it producing: the response resets, rd is destroyed
and its task gets cancelled.

Reinstate the disconnect survival inside the pipe. stream_pipe gains
finish_producer, which pumps the response next() into the ring buffer
until the generation ends, and mark_producer_done for the clean wire
end. server-http only triggers them: mark before sink.done on a clean
close, finish in on_complete when the peer left early. No detach, no
stream logic in server-http beyond the trigger, and the strict OAI path
is untouched when no pipe is attached.

Known limitation: finish_producer pumps synchronously on the http
worker, so a disconnected stream keeps its worker busy until the
generation ends. A follow-up will move the drain off the http worker so
no worker is held.

* server-stream: drain disconnected streams on a manager owned thread

The previous commit pumped the post disconnect drain synchronously in
on_complete, on the http worker, so a disconnected stream kept its
worker busy until the generation ended. Under a wave of reloads or tab
closes that pins workers from the pool.

Move the drain off the http worker. on_complete now hands the response
to stream_session_manager::adopt_orphan, which pumps it to completion on
a manager owned thread and releases the worker at once. One thread per
disconnected stream still generating, stored in a list, joined and
reaped on the next adopt, by the GC, and at shutdown. No detach, the
thread lifecycle is fully owned by the manager. needs_drain gates the
handoff so a cleanly finished stream never spawns a thread, and the
strict OAI path stays untouched when no pipe is attached.

stop_gc now cancels sessions before finalizing them, so an in flight
drain sees is_cancelled and exits instead of blocking the shutdown join
until the generation ends naturally.

* ui: add missing JSDoc

* server-stream: drain on the http worker, drop the manager thread

Address @ngxson review: httplib runs a large dynamic pool and a worker
blocked in next() sits on a condvar instead of burning cpu, so draining
the rest of the generation on that worker is fine and much simpler than
a dedicated thread.

on_complete calls finish_producer directly again. Removes adopt_orphan,
the orphan thread list and its reaping, the stop_gc session cancel that
only existed to unblock those threads, and the now dead drain_shutdown
flag.

* server-stream: split stream_pipe into producer and consumer classes

Address @ngxson review: one class covering both ends was messy. stream_pipe
is now a base holding the session and is_cancelled, with stream_pipe_producer
(write, mark_producer_done, finish_producer, cleanup, finalizes on destruct)
and stream_pipe_consumer (read only, no finalize) deriving from it.

Drops the is_producer_ discriminator and its runtime guards, the type now
encodes the role. res.spipe is retyped to shared_ptr<stream_pipe_producer>
since it is only ever a producer. No behavior change.

* server-stream: rename producer methods to unix pipe semantics

Address @ngxson review: mark_producer_done becomes done(), finish_producer
becomes close(), matching a unix pipe write end. The producer_done_ member
follows as done_. write() is unchanged. No behavior change.

* server, ui: route resumable streams via a conv map, persist resume identity

Address ngxson review: drop the polling probe, proxy_post records a conv_id ->
model map and the stream routes resolve the owning child with one lookup. The
map is the single source of truth, the ::model suffix stays for child session
uniqueness but the router never parses it.

UI: the server keys a session by the POST time identity (conv::model), but reload
probed with the bare conv id and missed model tagged sessions, so F5 stopped the
stream and sidebar spinners stayed off. Persist the model and rebuild the exact
identity on resume, single conv and bulk sidebar both send it.

Add unit coverage for the identity round trip.

* ui: resolve continue target by id to stop cross-conversation flash on switch

* ui: skip stream resume when the abort is intentional

* server: move the conv id to model map into a self contained tracker

Address review from ngxson: server_models held two mutexes side by side, the
global one and a bare conv_model_mu guarding a loose map, which made the locking
hard to follow. Wrap the map and its lock in a small conv_model_tracker struct
that owns its mutex, one mutex per struct. The remember, lookup and forget
methods move inline into the tracker, server_models exposes a single conv_models
member and the routes call models.conv_models.lookup and friends. No behavior
change, the map stays the single source of truth for routing resumable streams
to a child.

* ui: replace stream magic values with enums and shared constants

Address review from allozaur: lift the inline literals around the resumable
stream code into named symbols so the intent is explicit and reusable.

* ui: fold the stream resume and discovery helpers into ChatService

Address review from allozaur: drop the two standalone stream-*.service files.
They were used only by the chat service and store, carried no shared state, and
did not follow the static class pattern the other services use, so a separate
abstraction was not warranted. Move the helpers onto ChatService as static
methods. No behavior change, tests now exercise them through ChatService.

* docs: document the SSE replay buffer in server README-dev

Add the resumable streaming section, list stream_session_manager in the
backend component inventory, and link PR 23226 in the related PRs.

* ui: align attachServerStream call with onCompletionId param in handleStreamResponse

* server-http: rename del_ to del to match get and post

* ui: address review feedback from allozaur

* ui: drop duplicate SSE constants, keep sse.ts canonical

* ui: use svelte:document for the visibilitychange listener

address review from allozaur: replace the manual document.addEventListener
in onMount with a declarative <svelte:document onvisibilitychange>. svelte
handles attach, detach and SSR, so the typeof document guard and the onMount
cleanup go away. onMount keeps only the first load snapshot.

* server: trim redundant stream drain comments

Address review from ngxson

* server: balance and clean up stream comments

remove redundant comments and tighten the verbose ones across the resumable
stream code, keeping the concurrency and lifetime rationale that is not obvious
from the code. also fix two stale comments in server.cpp and server-models.h
that still described the old ::model suffix probe and fan out routing, now
replaced by the conv_id -> model map

Address review from ngxson

* ui: balance and clean up stream comments

dedup repeated rationale (frozen conv::model identity, the lookup privacy note,
the abort patterns) down to one canonical spot, tighten the verbose blocks, and
keep the concurrency and resume-offset reasoning. fix stale comments in
stream-identity.ts and chat.service.ts that still described the old loopback
probe and fan out routing, now the conv_id -> model map.

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-06-26 09:31:29 +02:00
Jassieluo e7e3f35090 sycl : clamp softmax input to avoid underflow (#24941) 2026-06-26 15:02:42 +08:00
Xuan-Son Nguyen b11f7c16bc mtmd: add more validations (#25013)
* mtmd: add more validations

* fix

* refactor a bit

* type check for get_arr_int
2026-06-26 08:43:29 +02:00
leonardHONG f818065d75 CUDA: batch out_prod broadcast (dps2>1) path with cublasSgemmBatched (#24426) 2026-06-26 08:51:25 +03:00
Arsen Arutunan 960d628f46 mamba2: remove hardcoded 2x expansion factor and invalid d_inner % d_state check (#23082)
* mamba2: remove hardcoded 2x expansion factor, support any expand value

* mamba2: remove invalid d_inner %% d_state check (unrelated parameters)

* Update convert_hf_to_gguf.py: make expand optional with default 2

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* mamba2: apply expand fix to refactored conversion/mamba.py

* also check for mamba_expand

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
2026-06-26 08:50:54 +03:00
153 changed files with 15279 additions and 1736 deletions
+2 -2
View File
@@ -145,7 +145,7 @@ ENTRYPOINT ["/app/tools.sh"]
# ==============================================================================
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
ENTRYPOINT [ "/app/llama-cli" ]
@@ -156,7 +156,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
HEALTHCHECK --interval=5m CMD [ "curl", "-f", "http://localhost:8080/health" ]
+2 -2
View File
@@ -104,7 +104,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -115,7 +115,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+2 -2
View File
@@ -113,7 +113,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -124,7 +124,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+2 -2
View File
@@ -141,7 +141,7 @@ ENTRYPOINT ["/app/tools.sh"]
FROM base AS light
COPY --from=build /app/lib/ /app
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -153,7 +153,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/lib/ /app
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+2 -2
View File
@@ -115,7 +115,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -126,7 +126,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+8 -8
View File
@@ -1,12 +1,12 @@
ARG OPENVINO_VERSION_MAJOR=2026.2
ARG OPENVINO_VERSION_FULL=2026.2.0.21903.52ddc073857
ARG OPENVINO_VERSION_MAJOR=2026.2.1
ARG OPENVINO_VERSION_FULL=2026.2.1.21919.ede283a88e3
ARG UBUNTU_VERSION=24.04
# Intel GPU driver versions. https://github.com/intel/compute-runtime/releases
ARG IGC_VERSION=v2.34.4
ARG IGC_VERSION_FULL=2_2.34.4+21428
ARG COMPUTE_RUNTIME_VERSION=26.18.38308.1
ARG COMPUTE_RUNTIME_VERSION_FULL=26.18.38308.1-0
ARG IGC_VERSION=v2.36.3
ARG IGC_VERSION_FULL=2_2.36.3+21719
ARG COMPUTE_RUNTIME_VERSION=26.22.38646.4
ARG COMPUTE_RUNTIME_VERSION_FULL=26.22.38646.4-0
ARG IGDGMM_VERSION=22.10.0
# Intel NPU driver versions. https://github.com/intel/linux-npu-driver/releases
@@ -214,7 +214,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app/
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app/
WORKDIR /app
@@ -225,7 +225,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app/
COPY --from=build /app/full/llama /app/full/llama-server /app/
WORKDIR /app
+2 -2
View File
@@ -127,7 +127,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -138,7 +138,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+2 -2
View File
@@ -124,7 +124,7 @@ WORKDIR /llama.cpp/bin
# Copy llama.cpp binaries and libraries
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama-cli /llama.cpp/bin/llama-completion /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama /llama.cpp/bin/llama-cli /llama.cpp/bin/llama-completion /llama.cpp/bin
ENTRYPOINT [ "/llama.cpp/bin/llama-cli" ]
@@ -138,7 +138,7 @@ WORKDIR /llama.cpp/bin
# Copy llama.cpp binaries and libraries
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama-server /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama /llama.cpp/bin/llama-server /llama.cpp/bin
EXPOSE 8080
+2 -2
View File
@@ -107,7 +107,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -118,7 +118,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+2 -2
View File
@@ -97,7 +97,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -108,7 +108,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+4 -4
View File
@@ -68,8 +68,8 @@ jobs:
env:
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Clone
@@ -96,8 +96,8 @@ jobs:
env:
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Clone
+4 -4
View File
@@ -39,8 +39,8 @@ jobs:
env:
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Clone
@@ -96,8 +96,8 @@ jobs:
env:
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Clone
+2 -2
View File
@@ -266,8 +266,8 @@ jobs:
env:
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Clone
+58 -11
View File
@@ -446,8 +446,8 @@ jobs:
env:
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Set OpenVINO version output
@@ -506,8 +506,11 @@ jobs:
cmake -B build/ReleaseOV -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_OPENVINO=ON \
-DHF_UI_VERSION=${{ needs.get-version.outputs.ui_version }}
cmake --build build/ReleaseOV --config Release -j $(nproc)
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
-DHF_UI_VERSION=${{ needs.get-version.outputs.ui_version }} \
${{ env.CMAKE_ARGS }}
cmake --build build/ReleaseOV --config Release --parallel
- name: ccache-clear
uses: ./.github/actions/ccache-clear
@@ -521,8 +524,26 @@ jobs:
- name: Pack artifacts
id: pack_artifacts
run: |
cp LICENSE ./build/ReleaseOV/bin/
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz --transform "s,^\.,llama-${{ steps.tag.outputs.name }}," -C ./build/ReleaseOV/bin .
dest=./build/ReleaseOV/bin
OPENVINO_ROOT=./openvino_toolkit
ov_lib="$OPENVINO_ROOT/runtime/lib/intel64"
# Bundle OpenVINO runtime libs + TBB. Binaries built with RPATH=$ORIGIN
# load these siblings without setupvars.sh / LD_LIBRARY_PATH.
cp -P "$ov_lib"/libopenvino.so* \
"$ov_lib"/libopenvino_c.so* \
"$ov_lib"/libopenvino_*_plugin.so \
"$ov_lib"/libopenvino_intel_npu_compiler*.so \
"$OPENVINO_ROOT"/runtime/3rdparty/tbb/lib/*.so* \
"$dest"
cp -P /usr/lib/x86_64-linux-gnu/libOpenCL.so.1* "$dest" 2>/dev/null || true
cp "$ov_lib"/cache.json "$dest" 2>/dev/null || true
# OpenVINO licensing
cp -r "$OPENVINO_ROOT"/docs/licensing "$dest"/openvino-licensing
cp LICENSE "$dest"
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz --transform "s,^\.,llama-${{ steps.tag.outputs.name }}," -C "$dest" .
- name: Upload artifacts
uses: actions/upload-artifact@v6
@@ -531,6 +552,9 @@ jobs:
name: llama-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz
windows-openvino:
needs: [check-release]
if: ${{ needs.check-release.outputs.should_release == 'true' }}
runs-on: windows-2022
outputs:
@@ -538,8 +562,8 @@ jobs:
env:
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Set OpenVINO version output
@@ -607,7 +631,9 @@ jobs:
-A x64 ^
-DCMAKE_BUILD_TYPE=Release ^
-DGGML_OPENVINO=ON ^
-DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake
-DLLAMA_BUILD_BORINGSSL=ON ^
-DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake ^
${{ env.CMAKE_ARGS }}
cmake --build build\ReleaseOV --config Release -- /m
@@ -624,8 +650,29 @@ jobs:
id: pack_artifacts
shell: powershell
run: |
Copy-Item LICENSE .\build\ReleaseOV\bin\
7z a -snl llama-${{ steps.tag.outputs.name }}-bin-win-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.zip .\build\ReleaseOV\bin\*
# Locate the extracted OpenVINO toolkit root (same pattern as the Build step).
$OPENVINO_ROOT = (Get-ChildItem -Directory openvino_toolkit | Select-Object -First 1).FullName
if (-not $OPENVINO_ROOT) {
Write-Error "OpenVINO toolkit folder not found under .\openvino_toolkit"
exit 1
}
$dest = ".\build\ReleaseOV\bin\Release"
$ovBin = Join-Path $OPENVINO_ROOT 'runtime\bin\intel64\Release'
Copy-Item -Path (Join-Path $ovBin '*.dll') -Destination $dest -Force
Copy-Item -Path (Join-Path $ovBin 'cache.json') -Destination $dest -Force
$tbbBin = Join-Path $OPENVINO_ROOT 'runtime\3rdparty\tbb\bin'
Copy-Item -Path (Join-Path $tbbBin 'tbb*.dll') -Destination $dest -Force
# OpenVINO licensing
$licensingDest = Join-Path $dest 'openvino-licensing'
New-Item -ItemType Directory -Force -Path $licensingDest | Out-Null
Copy-Item -Path (Join-Path $OPENVINO_ROOT 'docs\licensing\*') -Destination $licensingDest -Recurse -Force
Copy-Item LICENSE $dest
7z a -snl llama-${{ steps.tag.outputs.name }}-bin-win-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.zip $dest\*
- name: Upload artifacts
uses: actions/upload-artifact@v6
+1 -1
View File
@@ -80,7 +80,7 @@ To protect sensitive data from potential leaks or unauthorized access, it is cru
### Untrusted environments or networks
If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions:
* Do not use the RPC backend, [rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
* Do not use the RPC backend, [ggml-rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value.
* Encrypt your data if sending it over the network.
+8 -4
View File
@@ -50,6 +50,7 @@ struct command {
std::vector<std::string> aliases;
bool hidden;
int (*func)(int, char **);
bool flags = false; // allow --name
};
#ifdef LLAMA_INSTALL_BUILD
@@ -69,9 +70,9 @@ static const command cmds[] = {
{"fit-params", "Compute parameters to fit a model in device memory", {}, true, llama_fit_params },
{"quantize", "Quantize a model", {}, true, llama_quantize },
{"perplexity", "Compute model perplexity and KL divergence", {}, true, llama_perplexity },
{"version", "Show version", {}, false, version },
{"licenses", "Show third-party licenses", {"credits"}, false, licenses },
{"help", "Show available commands", {}, false, help },
{"version", "Show version", {}, false, version, true },
{"licenses", "Show third-party licenses", {"credits"}, false, licenses, true },
{"help", "Show available commands", {}, false, help, true },
};
#undef UPDATE_HIDDEN
@@ -108,7 +109,10 @@ static int help(int argc, char ** argv) {
return 0;
}
static bool matches(const std::string & arg, const command & cmd) {
static bool matches(std::string arg, const command & cmd) {
if (cmd.flags && arg.size() > 2 && arg[0] == '-' && arg[1] == '-') {
arg.erase(0, 2);
}
if (arg == cmd.name) {
return true;
}
-2
View File
@@ -94,10 +94,8 @@ add_library(${TARGET}
peg-parser.h
preset.cpp
preset.h
regex-partial.cpp
reasoning-budget.cpp
reasoning-budget.h
regex-partial.h
sampling.cpp
sampling.h
speculative.cpp
+61 -9
View File
@@ -352,6 +352,8 @@ static std::string get_default_local_path(const std::string & url) {
common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex) {
common_download_hf_plan plan;
common_download_hf_plan plan_spec;
common_download_hf_plan plan_voc;
common_download_opts opts;
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
@@ -377,7 +379,15 @@ common_models_handler common_models_handler_init(const common_params & params, l
plan = common_download_get_hf_plan(params.model, opts);
}
return common_models_handler{plan, opts};
if (!params.speculative.draft.mparams.hf_repo.empty()) {
plan_spec = common_download_get_hf_plan(params.speculative.draft.mparams, opts);
}
if (!params.vocoder.model.hf_repo.empty()) {
plan_voc = common_download_get_hf_plan(params.vocoder.model, opts);
}
return common_models_handler{plan, plan_spec, plan_voc, opts};
}
bool common_models_handler_is_preset_repo(const common_models_handler & handler) {
@@ -425,7 +435,9 @@ static std::vector<common_download_task> build_url_tasks(const common_params_mod
void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback) {
std::vector<common_download_task> tasks;
auto & plan = handler.plan;
auto & plan = handler.plan;
auto & plan_spec = handler.plan_spec;
auto & plan_voc = handler.plan_voc;
auto opts = handler.opts; // copy
opts.callback = callback;
@@ -455,7 +467,7 @@ void common_models_handler_apply(common_models_handler & handler, common_params
// the first part is what gets loaded, so point params.model.path at it
if (!url_tasks.empty()) {
std::string first_path = url_tasks.front().local_path;
url_tasks.front().on_done = [&]() { params.model.path = first_path; };
url_tasks.front().on_done = [&, first_path]() { params.model.path = first_path; };
}
for (auto & task : url_tasks) {
tasks.push_back(std::move(task));
@@ -484,19 +496,22 @@ void common_models_handler_apply(common_models_handler & handler, common_params
}
// handle hf_plan tasks
if (!plan.model_files.empty()) {
for (size_t i = 0; i < plan.model_files.size(); ++i) {
auto & model_file = plan.model_files[i];
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files, common_params_model & model) {
for (size_t i = 0; i < model_files.size(); ++i) {
auto & model_file = model_files[i];
bool is_first = (i == 0);
tasks.emplace_back(model_file, opts, [&, is_first]() {
if (is_first) {
// only use first part as model path
params.model.path = hf_cache::finalize_file(model_file);
model.path = hf_cache::finalize_file(model_file);
} else {
hf_cache::finalize_file(model_file);
}
});
}
};
if (!plan.model_files.empty()) {
add_tasks(plan.model_files, params.model);
}
if (!plan.mmproj.local_path.empty()) {
tasks.emplace_back(plan.mmproj, opts, [&]() {
@@ -522,9 +537,31 @@ void common_models_handler_apply(common_models_handler & handler, common_params
});
}
// handle plan_spec (e.g. --spec-draft-hf)
if (!plan_spec.model_files.empty()) {
add_tasks(plan_spec.model_files, params.speculative.draft.mparams);
}
// handle vocoder plan (e.g. --hf-repo-v)
if (!plan_voc.model_files.empty()) {
add_tasks(plan_voc.model_files, params.vocoder.model);
}
// run all tasks in parallel
if (!params.offline) {
common_download_run_tasks(tasks);
// if duplicated files are found, only download once (but still call on_done for each task)
std::unordered_map<std::string, common_download_task *> unique_tasks;
for (auto & task : tasks) {
auto it = unique_tasks.find(task.local_path);
if (it == unique_tasks.end()) {
unique_tasks[task.local_path] = &task;
}
}
std::vector<common_download_task> unique_tasks_vec;
for (auto & pair : unique_tasks) {
unique_tasks_vec.push_back(*pair.second);
}
common_download_run_tasks(unique_tasks_vec);
}
// download successful, update params with the downloaded paths
@@ -3259,6 +3296,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.reasoning_budget_message = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
add_opt(common_arg(
{"--reasoning-preserve"},
{"--no-reasoning-preserve"},
"preserve reasoning trace in the full history, not just the last assistant message (default: template default)\n"
"compatible with certain templates having 'supports_preserve_reasoning' capability\n"
"example: https://docs.z.ai/guides/capabilities/thinking-mode#preserved-thinking",
[](common_params & params, bool value) {
if (value) {
params.default_template_kwargs["preserve_reasoning"] = "true";
} else {
params.default_template_kwargs["preserve_reasoning"] = "false";
}
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING_PRESERVE"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
@@ -3434,7 +3485,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.offline = true;
}
).set_env("LLAMA_ARG_OFFLINE"));
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_OFFLINE"));
add_opt(common_arg(
{"-lv", "--verbosity", "--log-verbosity"}, "N",
string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"
@@ -3711,6 +3762,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"draft model for speculative decoding (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.draft.mparams.path = value;
params.speculative.draft.mparams.hf_file = value; // will be used if --spec-draft-hf is set
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
add_opt(common_arg(
+2
View File
@@ -133,6 +133,8 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
struct common_models_handler {
common_download_hf_plan plan;
common_download_hf_plan plan_spec;
common_download_hf_plan plan_voc;
common_download_opts opts;
};
+155
View File
@@ -912,6 +912,10 @@ static std::string common_chat_template_direct_apply_impl(
if (inputs.add_generation_prompt) {
inp["add_generation_prompt"] = true;
}
if (inp.contains("preserve_reasoning") && inp["preserve_reasoning"].is_boolean()) {
bool enabled = inp["preserve_reasoning"].get<bool>();
jinja::caps_apply_preserve_reasoning(ctx, enabled);
}
jinja::global_from_json(ctx, inp, inputs.mark_input);
@@ -2376,6 +2380,149 @@ static void func_args_not_string(json & messages) {
}
// MiniCPM5 format:
// - Reasoning: <think>{reasoning}</think> (optional)
// - Tool calls: <function name="foo"><param name="bar">value</param></function>
static common_chat_params common_chat_params_init_minicpm5(const common_chat_template & tmpl,
const autoparser::generation_params & inputs) {
common_chat_params data;
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
data.preserved_tokens = {
"<function",
"<param",
"</function>",
"</param>",
"<think>",
"</think>",
};
data.thinking_start_tag = "<think>";
data.thinking_end_tag = "</think>";
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, "<|im_start|>assistant" },
{ COMMON_CHAT_ROLE_TOOL, "<|im_start|>user\n<tool_response>" },
{ COMMON_CHAT_ROLE_USER, "<|im_start|>user" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|im_start|>system" },
};
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
if (inputs.has_continuation()) {
const auto & msg = inputs.continue_msg;
data.generation_prompt = "<|im_start|>assistant\n<think>\n" + msg.reasoning_content;
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
data.generation_prompt += "\n</think>\n\n" + msg.render_content();
}
data.prompt += data.generation_prompt;
}
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
auto generation_prompt = p.literal("<|im_start|>assistant\n");
auto reasoning = p.eps();
if (extract_reasoning) {
reasoning = ("<think>" << p.reasoning(p.until("</think>")) << "</think>") + p.space();
}
// Response format parser
if (has_response_format) {
return generation_prompt + reasoning + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
}
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
// CDATA lets a value carry characters that would otherwise close the tag (e.g.
// </param>); capture the inner text only, excluding the CDATA markers.
auto string_value = p.choice({
p.literal("<![CDATA[") + p.ac(p.tool_arg_string_value(p.until("]]>")) + p.literal("]]>"), "]]>") + p.tool_arg_close(p.literal("</param>")),
p.negate(p.literal("<![CDATA[")) + p.ac(p.tool_arg_string_value(p.until("</param>")) + p.tool_arg_close(p.literal("</param>")), "</param>")
});
auto tool_choice = p.choice();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
const std::string name = function.at("name");
auto params = function.contains("parameters") ? function.at("parameters") : json::object();
auto args = p.eps();
if (params.contains("properties") && params.at("properties").is_object() && !params.at("properties").empty()) {
auto schema_info = common_schema_info();
schema_info.resolve_refs(params);
auto arg_choice = p.choice();
for (const auto & [prop_name, prop_schema] : params.at("properties").items()) {
auto value_parser = p.eps();
if (schema_info.resolves_to_string(prop_schema)) {
value_parser = string_value;
} else {
value_parser = p.tool_arg_json_value(
p.schema(p.json(), "tool-" + name + "-arg-" + prop_name + "-schema", prop_schema, false)
) + p.tool_arg_close(p.literal("</param>"));
}
auto arg_rule = p.tool_arg(
p.tool_arg_open(p.literal("<param name=\"") + p.tool_arg_name(p.literal(prop_name)) + p.literal("\">")) +
value_parser
);
arg_choice |= arg_rule;
}
args = p.zero_or_more(arg_choice + p.space());
}
auto tool_parser = p.tool(
p.tool_open(p.literal("<function name=\"") + p.tool_name(p.literal(name)) + p.literal("\">"))
<< p.tool_args(args)
<< p.tool_close(p.literal("</function>")));
tool_choice |= p.rule("tool-" + name, tool_parser);
});
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
auto tool_calls = p.trigger_rule("tool-call", p.repeat(tool_choice + p.space(), 1, max_calls));
auto content = p.content(p.until("<function"));
return generation_prompt + reasoning + content + tool_calls + p.end();
}
return generation_prompt + reasoning + p.content(p.rest()) + p.end();
});
data.parser = parser.save();
if (include_grammar) {
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
builder.resolve_refs(schema);
});
if (has_response_format) {
auto schema = inputs.json_schema;
builder.resolve_refs(schema);
}
parser.build_grammar(builder, data.grammar_lazy);
});
data.grammar_triggers = {
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function" },
};
}
return data;
}
static json common_chat_extra_context() {
json ctx = json::object();
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
@@ -2468,6 +2615,14 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
return common_chat_params_init_gemma4(tmpl, params);
}
// MiniCPM5 - XML tool calls with <function name="..."><param name="...">...</param></function>
if (src.find("Tool usage guidelines:") != std::string::npos &&
src.find("<function name=\"") != std::string::npos &&
src.find("<param name=\"") != std::string::npos) {
LOG_DBG("Using specialized template: MiniCPM5\n");
return common_chat_params_init_minicpm5(tmpl, params);
}
return std::nullopt;
}
+47 -47
View File
@@ -225,7 +225,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
}
if (!SetPriorityClass(GetCurrentProcess(), p)) {
LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
COM_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
return false;
}
@@ -251,7 +251,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
}
if (setpriority(PRIO_PROCESS, 0, p) != 0) {
LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
COM_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
return false;
}
return true;
@@ -284,14 +284,14 @@ void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_para
if (n_set && n_set < cpuparams.n_threads) {
// Not enough set bits, may experience performance issues.
LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
COM_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
}
}
bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) {
size_t dash_loc = range.find('-');
if (dash_loc == std::string::npos) {
LOG_ERR("Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
COM_ERR("%s", "Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
return false;
}
@@ -303,7 +303,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
} else {
start_i = std::stoull(range.substr(0, dash_loc));
if (start_i >= GGML_MAX_N_THREADS) {
LOG_ERR("Start index out of bounds!\n");
COM_ERR("%s", "Start index out of bounds!\n");
return false;
}
}
@@ -313,7 +313,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
} else {
end_i = std::stoull(range.substr(dash_loc + 1));
if (end_i >= GGML_MAX_N_THREADS) {
LOG_ERR("End index out of bounds!\n");
COM_ERR("%s", "End index out of bounds!\n");
return false;
}
}
@@ -333,7 +333,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
}
size_t num_digits = mask.length() - start_i;
if (num_digits > 128) num_digits = 128;
num_digits = std::min<size_t>(num_digits, 128);
size_t end_i = num_digits + start_i;
@@ -348,7 +348,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
} else if (c >= 'A' && c <= 'F') {
id -= 'A' - 10;
} else {
LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
COM_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
return false;
}
@@ -379,21 +379,21 @@ void common_params_print_info(const common_params & params, bool print_devices)
#else
const char * build_type = " (debug)";
#endif
LOG_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
COM_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
LOG_INF("log_info: verbosity = %d (adjust with the `-lv N` CLI arg)\n", common_log_get_verbosity_thold());
COM_INF("%s: verbosity = %d (adjust with the `-lv N` CLI arg)\n", __func__, common_log_get_verbosity_thold());
// device enumeration creates a primary context on CUDA backends, skip it when the caller does not own any device
if (print_devices) {
LOG_INF("device_info:\n");
COM_TRC("%s", "device_info:\n");
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
LOG_INF(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
COM_TRC(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
}
}
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
COM_TRC("%s\n", common_params_get_system_info(params).c_str());
}
std::string common_params_get_system_info(const common_params & params) {
@@ -660,7 +660,7 @@ void string_process_escapes(std::string & input) {
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
const char * sep = strchr(data, '=');
if (sep == nullptr || sep - data >= 128) {
LOG_ERR("%s: malformed KV override '%s'\n", __func__, data);
COM_ERR("%s: malformed KV override '%s'\n", __func__, data);
return false;
}
llama_model_kv_override kvo;
@@ -683,20 +683,20 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
} else if (std::strcmp(sep, "false") == 0) {
kvo.val_bool = false;
} else {
LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
COM_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
return false;
}
} else if (strncmp(sep, "str:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
if (strlen(sep) > 127) {
LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
COM_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
return false;
}
strncpy(kvo.val_str, sep, 127);
kvo.val_str[127] = '\0';
} else {
LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
COM_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
return false;
}
overrides.emplace_back(std::move(kvo));
@@ -1199,8 +1199,8 @@ common_init_result::common_init_result(common_params & params, bool model_only)
auto cparams = common_context_params_to_llama(params);
if (params.fit_params) {
LOG_INF("%s: fitting params to device memory ...\n", __func__);
LOG_INF("%s: (for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n", __func__);
COM_TRC("%s", "fitting params to device memory ...\n");
COM_TRC("%s", "(for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n");
common_fit_params(params.model.path.c_str(), &mparams, &cparams,
params.tensor_split,
params.tensor_buft_overrides.data(),
@@ -1227,7 +1227,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
llama_adapter_lora_ptr lora;
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str());
COM_ERR("failed to load lora adapter '%s'\n", la.path.c_str());
pimpl->model.reset(model);
return;
}
@@ -1246,14 +1246,14 @@ common_init_result::common_init_result(common_params & params, bool model_only)
common_init_sampler_from_model(model, params.sampling);
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
COM_WRN("%s", "vocab does not have an EOS token, ignoring --ignore-eos\n");
params.sampling.ignore_eos = false;
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_TRC("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
COM_TRC("added %s logit bias = %f\n", common_token_to_piece(vocab, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
@@ -1291,7 +1291,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
return;
}
@@ -1328,7 +1328,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
llama_model * model = res->model();
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
COM_ERR("failed to load model '%s'\n", params.model.path.c_str());
return res;
}
@@ -1338,14 +1338,14 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
llama_context * lctx = res->context();
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
return res;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
COM_WRN("%s", "KV cache shifting is not supported for this context, disabling KV cache shifting\n");
params.ctx_shift = false;
}
@@ -1374,7 +1374,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
bool ok = true;
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
COM_WRN("%s", "vocab does not have a BOS token, reranking will not work\n");
ok = false;
}
@@ -1383,10 +1383,10 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;
if (!has_eos && !has_sep && !has_rerank_prompt) {
LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
COM_WRN("%s", "vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n");
ok = false;
} else if (!has_eos) {
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
COM_WRN("%s", "vocab does not have an EOS token, using SEP token as fallback\n");
}
if (!ok) {
@@ -1399,7 +1399,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
}
if (params.warmup) {
LOG_INF("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
COM_TRC("%s", "warming up the model with an empty run - please wait ... (--no-warmup to disable)\n");
std::vector<llama_token> tmp;
llama_token bos = llama_vocab_bos(vocab);
@@ -1473,20 +1473,20 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
if (ret != 0) {
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
COM_ERR("llama_decode() failed: %d\n", ret);
res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
goto done;
}
if (llama_n_rs_seq(ctx) > 0) {
LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__);
COM_TRC("%s", "the context supports bounded partial sequence removal\n");
res = COMMON_CONTEXT_SEQ_RM_TYPE_RS;
goto done;
}
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_TRC("%s: the context does not support partial sequence removal\n", __func__);
COM_TRC("%s", "the context does not support partial sequence removal\n");
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
goto done;
}
@@ -1803,13 +1803,13 @@ static common_control_vector_data common_control_vector_load_one(const common_co
};
struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
if (!ctx_gguf) {
LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
COM_ERR("failed to load control vector file from %s\n", load_info.fname.c_str());
return result;
}
int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
if (n_tensors == 0) {
LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
COM_WRN("no direction tensors found in %s\n", load_info.fname.c_str());
}
for (int i = 0; i < n_tensors; i++) {
@@ -1827,23 +1827,23 @@ static common_control_vector_data common_control_vector_load_one(const common_co
}
}
if (layer_idx < 0) {
LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid/unparsable direction tensor layer index in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
} else if (layer_idx == 0) {
LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid (zero) direction tensor layer index in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
if (tensor->type != GGML_TYPE_F32) {
LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid (non-F32) direction tensor type in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
if (ggml_n_dims(tensor) != 1) {
LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
COM_ERR("invalid (non-1D) direction tensor shape in %s\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
@@ -1851,7 +1851,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
if (result.n_embd == -1) {
result.n_embd = ggml_nelements(tensor);
} else if (ggml_nelements(tensor) != result.n_embd) {
LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
COM_ERR("direction tensor in %s does not match previous dimensions\n", load_info.fname.c_str());
result.n_embd = -1;
break;
}
@@ -1868,7 +1868,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
}
if (result.n_embd == -1) {
LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
COM_WRN("skipping %s due to invalid direction tensors\n", load_info.fname.c_str());
result.data.clear();
}
@@ -1889,7 +1889,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
break;
}
if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
COM_ERR("control vectors in %s does not match previous dimensions\n", info.fname.c_str());
result.n_embd = -1;
break;
}
@@ -1905,7 +1905,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
}
if (result.n_embd == -1) {
LOG_ERR("%s: no valid control vector files passed\n", __func__);
COM_ERR("%s", "no valid control vector files passed\n");
result.data.clear();
}
@@ -2016,13 +2016,13 @@ bool common_prompt_batch_decode(
// memory, so we can't just remove the last token from the memory and replay the last token which
// is the reason for this logic.
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_tokens_before_last))) {
LOG_ERR("%s : failed to eval\n", __func__);
COM_ERR("%s", "failed to eval\n");
return false;
}
n_past += n_tokens_before_last;
llama_state_save_file(ctx, state_path.data(), all_tokens.data(), all_tokens.size());
LOG_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
COM_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
llama_token last_token = all_tokens.back();
llama_batch batch = llama_batch_get_one(&last_token, 1);
@@ -2030,13 +2030,13 @@ bool common_prompt_batch_decode(
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval last token\n", __func__);
COM_ERR("%s", "failed to eval last token\n");
return false;
}
n_past++;
} else {
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_new))) {
LOG_ERR("%s : failed to eval\n", __func__);
COM_ERR("%s", "failed to eval\n");
return false;
}
n_past += n_new;
+9 -1
View File
@@ -25,6 +25,13 @@
#define DIRECTORY_SEPARATOR '/'
#endif // _WIN32
#define COM_DBG(fmt, ...) LOG_DBG("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_TRC(fmt, ...) LOG_TRC("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_INF(fmt, ...) LOG_INF("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_WRN(fmt, ...) LOG_WRN("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_ERR(fmt, ...) LOG_ERR("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define COM_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
@@ -162,6 +169,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction
COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, // DFlash speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
@@ -377,7 +385,7 @@ struct common_params_speculative {
uint32_t need_n_rs_seq() const {
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3 || t == COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH;
});
return needs_rs_seq ? draft.n_max : 0u;
+1 -1
View File
@@ -233,7 +233,7 @@ static void common_params_fit_impl(
sum_projected_used = dmds_full.back().mb.total();
sum_free = dmds_full.back().total;
sum_projected_free = sum_free - sum_projected_used;
LOG_INF("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
LOG_TRC("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
__func__, sum_projected_used/MiB, sum_free/MiB);
if (sum_projected_free >= margins[0]) {
LOG_TRC("%s: will leave %" PRId64 " >= %" PRId64 " MiB of system memory, no changes needed\n",
+44 -23
View File
@@ -16,22 +16,34 @@ using json = nlohmann::ordered_json;
namespace jinja {
using caps_json_fn = std::function<json()>;
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
using caps_ctx_fn = std::function<void(context &)>;
using caps_analyze_fn = std::function<void(bool, value &, value &, const std::string &)>;
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled) {
ctx.set_val("preserve_thinking", mk_val<value_bool>(enabled));
ctx.set_val("clear_thinking", mk_val<value_bool>(!enabled));
ctx.set_val("truncate_history_thinking", mk_val<value_bool>(!enabled));
}
static void caps_try_execute(jinja::program & prog,
const caps_json_fn & messages_fn,
const caps_ctx_fn & ctx_fn,
const caps_json_fn & tools_fn,
const caps_analyze_fn & analyze_fn) {
context ctx;
ctx.is_get_stats = true;
jinja::global_from_json(ctx, json{
{"messages", messages_fn()},
{"tools", tools_fn()},
{"tools", tools_fn ? tools_fn() : json::array()},
{"bos_token", ""},
{"eos_token", ""},
{"add_generation_prompt", true}
}, true);
if (ctx_fn) {
ctx_fn(ctx);
}
auto messages = ctx.get_val("messages");
auto tools = ctx.get_val("tools");
@@ -49,7 +61,7 @@ static void caps_try_execute(jinja::program & prog,
// ignore exceptions during capability analysis
}
analyze_fn(success, messages, tools);
analyze_fn(success, messages, tools, result);
}
// for debugging only
@@ -109,11 +121,9 @@ caps caps_get(jinja::program & prog) {
}
});
},
[&]() {
// tools
return json{nullptr};
},
[&](bool success, value & messages, value &) {
nullptr, // ctx_fn
nullptr, // tools_fn
[&](bool success, value & messages, value &, const std::string &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
@@ -145,11 +155,9 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&]() {
// tools
return json::array();
},
[&](bool, value & messages, value &) {
nullptr, // ctx_fn
nullptr, // tools_fn
[&](bool, value & messages, value &, const std::string &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (!content->stats.used) {
@@ -201,6 +209,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
@@ -224,7 +233,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & tools) {
[&](bool success, value & messages, value & tools, const std::string &) {
if (!success) {
return; // Nothing can be inferred
}
@@ -293,6 +302,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
@@ -316,7 +326,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & tools) {
[&](bool success, value & messages, value & tools, const std::string &) {
if (!success) {
result.supports_tool_calls = false;
result.supports_tools = false;
@@ -394,6 +404,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
nullptr, // ctx_fn
[&]() {
// tools
return json::array({
@@ -417,7 +428,7 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&](bool success, value & messages, value & /*tools*/) {
[&](bool success, value & messages, value &, const std::string &) {
if (!success) {
result.supports_parallel_tool_calls = false;
return;
@@ -438,11 +449,22 @@ caps caps_get(jinja::program & prog) {
JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning");
// case: preserve reasoning content in chat history
const std::string reasoning_placeholder = "<REASONING_CONTENT_PLACEHOLDER>";
caps_try_execute(
prog,
[&]() {
// messages
return json::array({
{
{"role", "user"},
{"content", "User message"}
},
{
{"role", "assistant"},
{"content", "Assistant message"},
// check of reasoning_content deeper in the history, not just the last assistant message
{"reasoning_content", reasoning_placeholder}
},
{
{"role", "user"},
{"content", "User message"}
@@ -458,14 +480,13 @@ caps caps_get(jinja::program & prog) {
},
});
},
[&]() {
// tools
return json::array();
[&](context & ctx) {
caps_apply_preserve_reasoning(ctx, true);
},
[&](bool, value & messages, value &) {
auto & content = messages->at(1)->at("reasoning_content");
caps_print_stats(content, "messages[1].reasoning_content");
if (content->stats.used) {
nullptr, // tools_fn
[&](bool, value &, value &, const std::string & output) {
// note: we cannot use stats here because the reasoning_content may be used for "if" condition test, but not actually outputted in the final result
if (output.find(reasoning_placeholder) != std::string::npos) {
result.supports_preserve_reasoning = true;
}
}
+5 -1
View File
@@ -12,7 +12,9 @@ struct caps {
bool supports_tool_calls = true;
bool supports_system_role = true;
bool supports_parallel_tool_calls = true;
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
// supports preserve reasoning trace in the full history, not just the last assistant message
bool supports_preserve_reasoning = false;
// one of the 2 content capabilities must be true
bool supports_string_content = true;
@@ -29,4 +31,6 @@ struct caps {
caps caps_get(jinja::program & prog);
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled);
} // namespace jinja
+46
View File
@@ -954,4 +954,50 @@ value keyword_argument_expression::execute_impl(context & ctx) {
return mk_val<value_kwarg>(k, v);
}
std::string runtime::debug_dump_program(const program & prog, const std::string & src) {
std::ostringstream oss;
size_t lvl = 0;
context ctx;
ctx.src.reset(new std::string(src));
auto indent = [](size_t lvl) -> std::string {
return std::string(lvl * 2, ' ');
};
ctx.visitor = [&](bool is_leaf, statement * node, std::vector<visitor_pair> children) {
oss << indent(lvl) << node->type() << ":\n";
lvl++;
if (is_leaf) {
const auto & pos = node->pos;
oss << indent(lvl) << "(leaf) at " << get_line_col(src, pos) << " in source:\n";
std::string snippet = peak_source(src, pos);
string_replace_all(snippet, "\n", "\n" + indent(lvl));
oss << indent(lvl) << snippet << "\n";
} else {
for (auto & [label, children_vec] : children) {
oss << indent(lvl) << label << ":\n";
lvl++;
if (children_vec.empty()) {
oss << indent(lvl) << "<empty>\n\n";
} else {
for (auto * child : children_vec) {
if (!child) {
continue;
}
child->visit(ctx);
}
}
lvl--;
}
}
lvl--;
};
for (const auto & stmt : prog.body) {
stmt->visit(ctx);
}
return oss.str();
}
} // namespace jinja
+127
View File
@@ -47,12 +47,19 @@ const T * cast_stmt(const statement_ptr & ptr) {
// not thread-safe
void enable_debug(bool enable);
// for visiting AST nodes
// function signature: void(bool is_leaf, statement * node, pair of <label, children>)
using visitor_pair = std::pair<std::string, std::vector<statement *>>;
using visitor_fn = std::function<void(bool, statement *, std::vector<visitor_pair>)>;
struct context {
std::shared_ptr<std::string> src; // for debugging; use shared_ptr to avoid copying on scope creation
std::time_t current_time; // for functions that need current time
bool is_get_stats = false; // whether to collect stats
visitor_fn visitor;
// src is optional, used for error reporting
context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
env = mk_val<value_object>();
@@ -99,6 +106,15 @@ private:
value_object env;
};
// utils for visiting AST nodes
static std::vector<statement *> stmts_to_ptr(const statements & stmts) {
std::vector<statement *> children;
for (const auto & stmt : stmts) {
children.push_back(stmt.get());
}
return children;
}
/**
* Base class for all nodes in the AST.
*/
@@ -106,6 +122,7 @@ struct statement {
size_t pos; // position in source, for debugging
virtual ~statement() = default;
virtual std::string type() const { return "Statement"; }
virtual void visit(context & ctx) { ctx.visitor(true, this, {}); }
// execute_impl must be overridden by derived classes
virtual value execute_impl(context &) { throw_exec_error(); }
@@ -166,6 +183,13 @@ struct if_statement : public statement {
std::string type() const override { return "If"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"test", {test.get()}},
{"body", stmts_to_ptr(body)},
{"alternate", stmts_to_ptr(alternate)}
});
}
};
struct identifier;
@@ -190,6 +214,14 @@ struct for_statement : public statement {
std::string type() const override { return "For"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"loopvar", {loopvar.get()}},
{"iterable", {iterable.get()}},
{"body", stmts_to_ptr(body)},
{"default_block", stmts_to_ptr(default_block)}
});
}
};
struct break_statement : public statement {
@@ -241,6 +273,13 @@ struct set_statement : public statement {
std::string type() const override { return "Set"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"assignee", {assignee.get()}},
{"value", {val.get()}},
{"body", stmts_to_ptr(body)}
});
}
};
struct macro_statement : public statement {
@@ -256,6 +295,13 @@ struct macro_statement : public statement {
std::string type() const override { return "Macro"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"name", {name.get()}},
{"args", stmts_to_ptr(args)},
{"body", stmts_to_ptr(body)}
});
}
};
struct comment_statement : public statement {
@@ -289,6 +335,12 @@ struct member_expression : public expression {
}
std::string type() const override { return "MemberExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"object", {object.get()}},
{"property", {property.get()}}
});
}
};
struct call_expression : public expression {
@@ -302,6 +354,12 @@ struct call_expression : public expression {
}
std::string type() const override { return "CallExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"callee", {callee.get()}},
{"args", stmts_to_ptr(args)}
});
}
};
/**
@@ -405,6 +463,12 @@ struct binary_expression : public expression {
}
std::string type() const override { return "BinaryExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"left", {left.get()}},
{"right", {right.get()}}
});
}
};
/**
@@ -431,6 +495,12 @@ struct filter_expression : public expression {
std::string type() const override { return "FilterExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"operand", {operand.get()}},
{"filter", {filter.get()}}
});
}
};
struct filter_statement : public statement {
@@ -443,6 +513,12 @@ struct filter_statement : public statement {
}
std::string type() const override { return "FilterStatement"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"filter", {filter.get()}},
{"body", stmts_to_ptr(body)}
});
}
};
/**
@@ -468,6 +544,12 @@ struct select_expression : public expression {
}
return lhs->execute_impl(ctx);
}
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"lhs", {lhs.get()}},
{"test", {test.get()}}
});
}
};
/**
@@ -486,6 +568,12 @@ struct test_expression : public expression {
}
std::string type() const override { return "TestExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"operand", {operand.get()}},
{"test", {test.get()}}
});
}
};
/**
@@ -501,6 +589,11 @@ struct unary_expression : public expression {
}
std::string type() const override { return "UnaryExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"argument", {argument.get()}}
});
}
};
struct slice_expression : public expression {
@@ -518,6 +611,13 @@ struct slice_expression : public expression {
[[noreturn]] value execute_impl(context &) override {
throw std::runtime_error("must be handled by MemberExpression");
}
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"start_expr", {start_expr.get()}},
{"stop_expr", {stop_expr.get()}},
{"step_expr", {step_expr.get()}}
});
}
};
struct keyword_argument_expression : public expression {
@@ -531,6 +631,12 @@ struct keyword_argument_expression : public expression {
}
std::string type() const override { return "KeywordArgumentExpression"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"key", {key.get()}},
{"val", {val.get()}}
});
}
};
struct spread_expression : public expression {
@@ -539,6 +645,11 @@ struct spread_expression : public expression {
chk_type<expression>(this->argument);
}
std::string type() const override { return "SpreadExpression"; }
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"argument", {argument.get()}}
});
}
};
struct call_statement : public statement {
@@ -553,6 +664,13 @@ struct call_statement : public statement {
}
std::string type() const override { return "CallStatement"; }
value execute_impl(context & ctx) override;
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"call", {call.get()}},
{"caller_args", stmts_to_ptr(caller_args)},
{"body", stmts_to_ptr(body)}
});
}
};
struct ternary_expression : public expression {
@@ -575,6 +693,13 @@ struct ternary_expression : public expression {
return false_expr->execute(ctx);
}
}
void visit(context & ctx) override {
ctx.visitor(false, this, {
{"condition", {condition.get()}},
{"true_expr", {true_expr.get()}},
{"false_expr", {false_expr.get()}}
});
}
};
struct raised_exception : public std::exception {
@@ -648,6 +773,8 @@ struct runtime {
}
return parts;
}
static std::string debug_dump_program(const program & prog, const std::string & src);
};
} // namespace jinja
+44
View File
@@ -1108,6 +1108,50 @@ const func_builtins & value_array_t::get_builtins() const {
std::reverse(arr.begin(), arr.end());
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
}},
{"min", [](const func_args & args) -> value {
args.ensure_count(1, 4);
args.ensure_vals<value_array>();
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
value attribute = args.get_kwarg_or_pos("attribute", 2);
if (!attribute->is_undefined()) {
throw not_implemented_exception("min: attribute not implemented");
}
// FIXME: min is currently always case sensitive
(void) val_case;
const auto & arr = args.get_pos(0)->as_array();
if (arr.empty()) {
return mk_val<value_undefined>();
}
value result = arr[0];
for (size_t i = 1; i < arr.size(); ++i) {
if (value_compare(arr[i], result, value_compare_op::lt)) {
result = arr[i];
}
}
return result;
}},
{"max", [](const func_args & args) -> value {
args.ensure_count(1, 4);
args.ensure_vals<value_array>();
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
value attribute = args.get_kwarg_or_pos("attribute", 2);
if (!attribute->is_undefined()) {
throw not_implemented_exception("max: attribute not implemented");
}
// FIXME: max is currently always case sensitive
(void) val_case;
const auto & arr = args.get_pos(0)->as_array();
if (arr.empty()) {
return mk_val<value_undefined>();
}
value result = arr[0];
for (size_t i = 1; i < arr.size(); ++i) {
if (value_compare(arr[i], result, value_compare_op::gt)) {
result = arr[i];
}
}
return result;
}},
{"unique", array_unique_not_implemented},
};
return builtins;
+10 -10
View File
@@ -65,12 +65,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
if (ctx->start_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
COM_TRC("activated, budget=%d tokens\n", ctx->budget);
if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
COM_TRC("%s", "budget=0, forcing immediately\n");
}
}
break;
@@ -80,7 +80,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
{
if (ctx->end_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: deactivated (natural end)\n");
COM_TRC("%s", "deactivated (natural end)\n");
break;
}
@@ -95,7 +95,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
COM_TRC("%s", "UTF-8 complete, now forcing end sequence\n");
}
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
ctx->remaining--;
@@ -104,11 +104,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
COM_TRC("%s", "budget exhausted, forcing end sequence\n");
} else {
ctx->state = REASONING_BUDGET_WAITING_UTF8;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
COM_TRC("%s", "budget exhausted, waiting for UTF-8 completion\n");
}
}
}
@@ -118,7 +118,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
COM_TRC("%s", "forced sequence complete, done\n");
}
break;
case REASONING_BUDGET_DONE:
@@ -128,12 +128,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);
COM_TRC("re-activated on new start tag, budget=%d tokens\n", ctx->budget);
if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
COM_TRC("%s", "budget=0, forcing immediately\n");
}
}
break;
@@ -264,7 +264,7 @@ bool common_reasoning_budget_force(struct llama_sampler * smpl) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: forced into forcing state (manual transition)\n");
COM_TRC("%s", "forced into forcing state (manual transition)\n");
return true;
}
-204
View File
@@ -1,204 +0,0 @@
#include "regex-partial.h"
#include "common.h"
#include <functional>
#include <optional>
common_regex::common_regex(const std::string & pattern) :
pattern(pattern),
rx(pattern),
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
std::smatch match;
if (pos > input.size()) {
throw std::runtime_error("Position out of bounds");
}
auto start = input.begin() + pos;
auto found = as_match
? std::regex_match(start, input.end(), match, rx)
: std::regex_search(start, input.end(), match, rx);
if (found) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
for (size_t i = 0; i < match.size(); ++i) {
auto begin = pos + match.position(i);
res.groups.emplace_back(begin, begin + match.length(i));
}
return res;
}
std::match_results<std::string::const_reverse_iterator> srmatch;
if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
auto group = srmatch[1].str();
if (group.length() != 0) {
auto it = srmatch[1].second.base();
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
if ((!as_match) || it == input.begin()) {
common_regex_match res;
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
const size_t begin = std::distance(input.begin(), it);
const size_t end = input.size();
if (begin == std::string::npos || end == std::string::npos || begin > end) {
throw std::runtime_error("Invalid range");
}
res.groups.push_back({begin, end});
return res;
}
}
}
return {};
}
/*
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
- /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
- /a|b/ -> ^(a|b)
- /a*?/ -> error, could match ""
- /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
- /.*?ab/ -> ^((?:b)?a) (omit .*)
- /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
- /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
- /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
- /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern.
All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
*/
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
auto it = pattern.begin();
const auto end = pattern.end();
std::function<std::string()> process = [&]() {
std::vector<std::vector<std::string>> alternatives(1);
std::vector<std::string> * sequence = &alternatives.back();
while (it != end) {
if (*it == '[') {
auto start = it;
++it;
while (it != end) {
if ((*it == '\\') && (++it != end)) {
++it;
} else if ((it != end) && (*it == ']')) {
break;
} else {
++it;
}
}
if (it == end) {
throw std::runtime_error("Unmatched '[' in pattern");
}
++it;
sequence->push_back(std::string(start, it));
} else if (*it == '*' || *it == '?' || *it == '+') {
if (sequence->empty()) {
throw std::runtime_error("Quantifier without preceding element");
}
sequence->back() += *it;
auto is_star = *it == '*';
++it;
if (is_star) {
if (it != end && *it == '?') {
++it;
}
}
} else if (*it == '{') {
if (sequence->empty()) {
throw std::runtime_error("Repetition without preceding element");
}
++it;
auto start = it;
while (it != end && *it != '}') {
++it;
}
if (it == end) {
throw std::runtime_error("Unmatched '{' in pattern");
}
auto parts = string_split(std::string(start, it), ",");
++it;
if (parts.size() > 2) {
throw std::runtime_error("Invalid repetition range in pattern");
}
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
if (s.empty()) {
return def;
}
return std::stoi(s);
};
auto min = parseOptInt(parts[0], 0);
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
if (min && max && *max < *min) {
throw std::runtime_error("Invalid repetition range in pattern");
}
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
auto part = sequence->back();
sequence->pop_back();
for (int i = 0; i < *min; i++) {
sequence->push_back(part);
}
if (max) {
for (int i = *min; i < *max; i++) {
sequence->push_back(part + "?");
}
} else {
sequence->push_back(part + "*");
}
} else if (*it == '(') {
++it;
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
it += 2;
}
auto sub = process();
if (*it != ')') {
throw std::runtime_error("Unmatched '(' in pattern");
}
++it;
auto & part = sequence->emplace_back("(?:");
part += sub;
part += ")";
} else if (*it == ')') {
break;
} else if (*it == '|') {
++it;
alternatives.emplace_back();
sequence = &alternatives.back();
} else if (*it == '\\' && (++it != end)) {
auto str = std::string("\\") + *it;
sequence->push_back(str);
++it;
} else if (it != end) {
sequence->push_back(std::string(1, *it));
++it;
}
}
// /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
// We'll do the outermost capturing group and final .* in the enclosing function.
std::vector<std::string> res_alts;
for (const auto & parts : alternatives) {
auto & res = res_alts.emplace_back();
for (size_t i = 0; i < parts.size() - 1; i++) {
res += "(?:";
}
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
res += *it;
if (it != parts.rend() - 1) {
res += ")?";
}
}
}
return string_join(res_alts, "|");
};
auto res = process();
if (it != end) {
throw std::runtime_error("Unmatched '(' in pattern");
}
return "^(" + res + ")";
}
-56
View File
@@ -1,56 +0,0 @@
#pragma once
#include <regex>
#include <string>
enum common_regex_match_type {
COMMON_REGEX_MATCH_TYPE_NONE,
COMMON_REGEX_MATCH_TYPE_PARTIAL,
COMMON_REGEX_MATCH_TYPE_FULL,
};
struct common_string_range {
size_t begin;
size_t end;
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
if (begin > end) {
throw std::runtime_error("Invalid range");
}
}
// prevent default ctor
common_string_range() = delete;
bool empty() const {
return begin == end;
}
bool operator==(const common_string_range & other) const {
return begin == other.begin && end == other.end;
}
};
struct common_regex_match {
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
std::vector<common_string_range> groups;
bool operator==(const common_regex_match & other) const {
return type == other.type && groups == other.groups;
}
bool operator!=(const common_regex_match & other) const {
return !(*this == other);
}
};
class common_regex {
std::string pattern;
std::regex rx;
std::regex rx_reversed_partial;
public:
explicit common_regex(const std::string & pattern);
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
const std::string & str() const { return pattern; }
};
// For testing only (pretty print of failures).
std::string regex_to_reversed_partial_regex(const std::string & pattern);
+371 -65
View File
@@ -18,6 +18,13 @@
#include <map>
#include <cinttypes>
#define SPC_DBG(fmt, ...) LOG_DBG("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_TRC(fmt, ...) LOG_TRC("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_INF(fmt, ...) LOG_INF("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_WRN(fmt, ...) LOG_WRN("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_ERR(fmt, ...) LOG_ERR("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SPC_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -26,6 +33,7 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
{"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
{"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
{"draft-mtp", COMMON_SPECULATIVE_TYPE_DRAFT_MTP},
{"draft-dflash", COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH},
{"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -60,21 +68,20 @@ static bool common_speculative_are_compatible(
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
SPC_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
SPC_DBG("vocab_type dft: %d\n", vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LOG_WRN("%s: draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
SPC_WRN("draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
return false;
}
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
(llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
__func__,
SPC_WRN("draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
return false;
@@ -82,8 +89,7 @@ static bool common_speculative_are_compatible(
if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
(llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
__func__,
SPC_WRN("draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
return false;
@@ -97,8 +103,8 @@ static bool common_speculative_are_compatible(
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
SPC_DBG("draft model vocab must closely match target model to use speculation but "
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false;
}
@@ -108,8 +114,8 @@ static bool common_speculative_are_compatible(
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
SPC_DBG("draft model vocab must match target model to use speculation but "
"token %d content differs - target '%s', draft '%s'\n", i,
common_token_to_piece(vocab_tgt, i).c_str(),
common_token_to_piece(vocab_dft, i).c_str());
return false;
@@ -186,9 +192,9 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
auto * ctx_tgt = this->params.ctx_tgt;
LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'draft-simple'\n");
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f\n", this->params.n_max, this->params.n_min, this->params.p_min);
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
ggml_type_name(this->params.cache_type_v),
@@ -228,16 +234,16 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
}
const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt);
SPC_DBG("vocab_cmpt = %d\n", vocab_cmpt);
if (!vocab_cmpt) {
LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__);
SPC_ERR("%s", "the target and draft vocabs are not compatible\n");
throw std::runtime_error("draft model vocab type must match target model to use speculation");
}
if (n_seq != llama_n_seq_max(ctx_dft)) {
LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft));
SPC_ERR("n_seq mismatch: %d != %d\n", n_seq, llama_n_seq_max(ctx_dft));
throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq");
}
@@ -257,7 +263,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
const int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret);
SPC_ERR("failed to decode draft batch, ret = %d\n", ret);
return false;
}
@@ -290,7 +296,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
SPC_ERR("llama_decode returned %d\n", ret);
return;
}
@@ -314,7 +320,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
@@ -354,7 +360,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
break;
}
@@ -449,8 +455,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq)
, params(params.draft)
{
LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
SPC_TRC("%s", "adding speculative implementation 'draft-eagle3'\n");
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
@@ -493,7 +499,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
llama_sampler_free(chain);
chain = nullptr;
}
@@ -548,9 +554,9 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 2) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
SPC_WRN("ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 2);
(int) pos_max, N - 2);
}
}
@@ -621,8 +627,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
};
const int32_t rc = llama_encode(ctx_dft, enc_batch);
if (rc != 0) {
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
__func__, rc, (int) n_chunk, (int) i);
SPC_ERR("llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
rc, (int) n_chunk, (int) i);
return false;
}
@@ -692,8 +698,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
if (batch.n_tokens > 0) {
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
__func__, rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
SPC_ERR("llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
return false;
}
}
@@ -744,7 +750,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
SPC_ERR("llama_decode returned %d\n", ret);
return;
}
@@ -770,7 +776,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
@@ -809,7 +815,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
break;
}
@@ -893,6 +899,296 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
}
};
// DFlash: block-diffusion drafting with a draft-side KV cache injection
struct common_speculative_impl_draft_dflash : public common_speculative_impl {
common_params_speculative_draft params;
llama_batch batch; // noise tokens
llama_batch batch_inject; // target features for KV cache injection
std::vector<common_sampler_ptr> smpls;
int32_t n_embd_dec = 0; // draft hidden size
int32_t n_embd_enc = 0; // target_layer_ids_n * target_hidden_size
int32_t n_embd_tgt = 0; // target model hidden size
int32_t block_size = 0;
llama_token mask_token_id = 0;
const int32_t * target_layer_ids = nullptr; // model_dft's extract layer indices
uint32_t target_layer_ids_n = 0;
// scratch buffer for concatenated target features [n_tokens, n_embd_enc]
std::vector<float> features_buf;
common_speculative_impl_draft_dflash(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, n_seq)
, params(params.draft)
{
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "DFlash requires ctx_tgt and ctx_dft to be set");
const llama_model * model_dft = llama_get_model(ctx_dft);
const llama_model * model_tgt = llama_get_model(ctx_tgt);
target_layer_ids = llama_model_target_layer_ids (model_dft);
target_layer_ids_n = llama_model_target_layer_ids_n(model_dft);
GGML_ASSERT(target_layer_ids_n > 0 && "DFlash model has no target_layer_ids");
n_embd_tgt = llama_model_n_embd(model_tgt);
n_embd_dec = llama_model_n_embd(model_dft);
n_embd_enc = (int32_t) target_layer_ids_n * n_embd_tgt;
// read the trained block size from the dflash.block_size metadata key
block_size = 16;
{
char buf[32] = {};
if (llama_model_meta_val_str(model_dft, "dflash.block_size", buf, sizeof(buf)) >= 0) {
block_size = std::atoi(buf);
}
}
mask_token_id = llama_vocab_mask(llama_model_get_vocab(model_dft));
LOG_INF("%s: adding speculative implementation 'draft-dflash'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
LOG_INF("%s: - block_size=%d, mask_token_id=%d, n_extract=%u\n", __func__, block_size, mask_token_id, target_layer_ids_n);
// DFlash input is [id_last, <mask> * (block_size-1)], so it can draft at most block_size-1 tokens per step
if (this->params.n_max > block_size - 1) {
LOG_WRN("%s: requested draft size %d exceeds the trained DFlash block size %d -- clamping to %d draft tokens per step\n",
__func__, this->params.n_max, block_size - 1, block_size - 1);
this->params.n_max = block_size - 1;
}
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, n_seq);
batch_inject = llama_batch_init(llama_n_batch(ctx_dft), n_embd_dec, n_seq);
smpls.resize(n_seq);
for (auto & s : smpls) {
common_params_sampling sparams;
sparams.no_perf = false;
sparams.top_k = 1;
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
s.reset(common_sampler_init(model_dft, sparams));
}
// turn on extraction of the target layers' input embeddings
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
llama_set_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k], true);
}
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
llama_set_causal_attn(ctx_dft, false); // DFlash needs non-causal attention
}
~common_speculative_impl_draft_dflash() override {
llama_batch_free(batch);
llama_batch_free(batch_inject);
}
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
return;
}
const int32_t N = (int32_t) prompt.size();
if (N <= 0) {
return;
}
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(params.ctx_dft), seq_id);
if (pos_max < N - 1) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - process() did not run on every prefill ubatch. "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 1);
}
}
bool process(const llama_batch & batch_in) override {
if (batch_in.n_tokens <= 0) {
return true;
}
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
return true;
}
const int32_t n_tokens = batch_in.n_tokens;
// per-seq inclusive batch range (assumes each seq's tokens are contiguous in the batch)
std::vector<int32_t> i_batch_beg(n_seq, -1);
std::vector<int32_t> i_batch_end(n_seq, -1);
for (int32_t k = 0; k < n_tokens; ++k) {
GGML_ASSERT(batch_in.n_seq_id[k] == 1);
const llama_seq_id seq_id = batch_in.seq_id[k][0];
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
continue;
}
i_batch_end[seq_id] = k;
if (i_batch_beg[seq_id] < 0) {
i_batch_beg[seq_id] = k;
}
}
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
const int32_t n_ubatch = (int32_t) llama_n_ubatch(ctx_dft);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
const int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1;
for (int32_t offset = 0; offset < n_rows; offset += n_ubatch) {
const int32_t n_chunk = std::min(n_ubatch, n_rows - offset);
// gather this chunk's target features, interleaved by extract layer
features_buf.resize((size_t) n_chunk * n_embd_enc);
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
const float * layer = llama_get_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k]);
if (!layer) {
GGML_ABORT("DFlash: target layer %d input not extracted.", target_layer_ids[k]);
}
for (int32_t i = 0; i < n_chunk; ++i) {
float * dst = features_buf.data() + (size_t) i * n_embd_enc + k * (size_t) n_embd_tgt;
const float * src = layer + (size_t) (i_batch_beg[seq_id] + offset + i) * n_embd_tgt;
std::memcpy(dst, src, (size_t) n_embd_tgt * sizeof(float));
}
}
// fuse extracted features through DFlash encoder
llama_batch enc_batch = {
/*.n_tokens =*/ n_chunk,
/*.token =*/ nullptr,
/*.embd =*/ features_buf.data(),
/*.pos =*/ nullptr,
/*.n_seq_id =*/ nullptr,
/*.seq_id =*/ nullptr,
/*.logits =*/ nullptr,
};
int32_t rc = llama_encode(ctx_dft, enc_batch);
if (rc != 0) {
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
__func__, rc, (int) n_chunk, (int) offset);
return false;
}
const float * inp_g = llama_get_embeddings_nextn(ctx_dft);
GGML_ASSERT(inp_g && "DFlash encoder produced no output.");
// inject the DFlash decoder K/V cache at the tokens' target positions
batch_inject.n_tokens = n_chunk;
std::memcpy(batch_inject.embd, inp_g, (size_t) n_chunk * n_embd_dec * sizeof(float));
for (int32_t i = 0; i < n_chunk; ++i) {
batch_inject.pos[i] = batch_in.pos[i_batch_beg[seq_id] + offset + i];
batch_inject.n_seq_id[i] = 1;
batch_inject.seq_id[i][0] = seq_id;
batch_inject.logits[i] = false;
}
rc = llama_decode(ctx_dft, batch_inject);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
__func__, rc, (int) n_chunk, (int) offset);
return false;
}
}
}
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
auto & ctx_dft = params.ctx_dft;
common_batch_clear(batch);
// build one batch holding every drafting sequence's noise block into a single decode)
// record where each block starts and its size
std::vector<int32_t> i_block_beg(n_seq, -1);
std::vector<int32_t> n_block (n_seq, 0);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
common_sampler_reset(smpls[seq_id].get());
const int32_t n = (int32_t) dp.n_past;
int32_t n_draft = params.n_max;
if (dp.n_max > 0) {
n_draft = std::min(n_draft, dp.n_max);
}
const int32_t n_block_tokens = n_draft + 1; // id_last + n_draft * <mask>
i_block_beg[seq_id] = batch.n_tokens;
n_block [seq_id] = n_block_tokens;
for (int32_t i = 0; i < n_block_tokens; ++i) {
common_batch_add(batch, i == 0 ? dp.id_last : mask_token_id, n + i, { seq_id }, true);
}
}
if (batch.n_tokens == 0) {
return;
}
// decode all sequence's noise block in a single batch
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_block_beg[seq_id] < 0) {
continue;
}
auto & dp = dparams[seq_id];
const int32_t beg = i_block_beg[seq_id];
const int32_t n_block_tokens = n_block[seq_id];
auto * smpl = smpls[seq_id].get();
auto & result = *dp.result;
// greedily read the predicted block at this sequence's noise positions 1..n_block_tokens-1
for (int32_t i = 1; i < n_block_tokens; ++i) {
common_sampler_sample(smpl, ctx_dft, beg + i, true);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i - 1, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
const llama_token id = cur_p->data[0].id;
common_sampler_accept(smpl, id, true);
result.push_back(id);
}
}
}
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
// noop
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_impl_draft_mtp : public common_speculative_impl {
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
@@ -942,9 +1238,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
"MTP input row width must match the target h_nextn width");
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'draft-mtp'\n");
SPC_TRC("- n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
ggml_type_name(this->params.cache_type_v),
@@ -975,7 +1271,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
llama_sampler_free(chain);
chain = nullptr;
}
@@ -1038,11 +1334,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1 && !is_mem_shared) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
SPC_WRN("ctx_dft pos_max=%d < N-1=%d - "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 1);
(int) pos_max, N - 1);
}
}
@@ -1128,8 +1424,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
__func__, head, (int) rc, (int) batch_in.pos[0]);
SPC_ERR("llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
head, (int) rc, (int) batch_in.pos[0]);
ok = false;
break;
}
@@ -1217,7 +1513,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
break;
}
@@ -1239,7 +1535,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
@@ -1353,8 +1649,8 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
, params(params.ngram_simple)
, config(config)
{
LOG_INF("%s: adding speculative implementation 'ngram-simple'\n", __func__);
LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'ngram-simple'\n");
SPC_TRC("- size_n=%d, size_m=%d, min_hits=%d\n",
this->params.size_n, this->params.size_m, this->params.min_hits);
}
@@ -1403,8 +1699,8 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
this->config.push_back(config);
}
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str());
LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__,
SPC_TRC("adding speculative implementation '%s'\n", common_speculative_type_to_str(this->type).c_str());
SPC_TRC("- size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n",
config.size_key, config.size_value, config.key_only, config.min_hits);
}
@@ -1478,15 +1774,15 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
, verbose(std::getenv("LLAMA_TRACE") != nullptr) {
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__);
LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'ngram-mod'\n");
SPC_TRC("- n_match=%d, n_max=%d, n_min=%d\n",
this->params.n_match, this->params.n_max, this->params.n_min);
LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__,
SPC_TRC("- mod size=%zu (%.3f MB)\n",
mod.size(), (float)(mod.size_bytes())/1024/1024);
if (this->params.n_match < 16) {
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match);
SPC_WRN("ngram_mod n_match=%d is too small - poor quality is possible, "
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", this->params.n_match);
}
sinfos.resize(n_seq);
@@ -1510,11 +1806,11 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
sinfo.i_last = prompt.size() - n;
const double f = (double)mod.get_used() / (double)mod.size();
LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f);
SPC_TRC("ngram_mod occupancy = %zu/%zu (%.2f)\n", mod.get_used(), mod.size(), f);
constexpr double f_thold = 0.25;
if (f > f_thold) {
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
SPC_WRN("ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", f, f_thold);
mod.reset();
}
@@ -1608,7 +1904,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
sinfo.n_low++;
if (sinfo.n_low >= 5) {
if (verbose) {
LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low);
SPC_TRC("low acceptance streak (%d) - resetting ngram_mod\n", sinfo.n_low);
}
mod.reset();
@@ -1658,8 +1954,8 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
, save_dynamic(save_dynamic)
, save_static(save_static)
{
LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__);
LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__,
SPC_TRC("%s", "adding speculative implementation 'ngram-cache'\n");
SPC_TRC("- n_draft=%d, cache_static=%s, cache_dynamic=%s\n",
n_draft,
path_static.empty() ? "none" : path_static.c_str(),
path_dynamic.empty() ? "none" : path_dynamic.c_str());
@@ -1674,7 +1970,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
sinfo.ngram_cache_static = ngram_cache_static;
}
} catch (...) {
LOG_ERR("failed to open static lookup cache: %s", path_static.c_str());
SPC_ERR("failed to open static lookup cache: %s", path_static.c_str());
GGML_ABORT("Couldn't read static lookup cache");
}
}
@@ -1687,7 +1983,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
sinfo.ngram_cache_dynamic = ngram_cache_dynamic;
}
} catch (...) {
LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
SPC_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
GGML_ABORT("Couldn't read dynamic lookup cache");
}
}
@@ -1836,6 +2132,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) {
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp";
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH: return "draft-dflash";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
@@ -1888,6 +2185,7 @@ int32_t common_speculative_n_max(const common_params_speculative * spec) {
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH:
n_max = std::max(n_max, std::max(0, spec->draft.n_max));
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
@@ -1925,6 +2223,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
bool has_draft_dflash = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH)) && params.draft.ctx_dft != nullptr;
@@ -1935,7 +2234,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD));
// when adding a new type - update here the logic above
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9);
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 10);
// this list here defines the priority of the speculators
// the one with highest priority are listed first
@@ -1965,6 +2264,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_draft_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
}
if (has_draft_dflash) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, params));
}
}
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
@@ -1985,6 +2287,10 @@ common_speculative * common_speculative_init(common_params_speculative & params,
impls.push_back(std::make_unique<common_speculative_impl_draft_mtp>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH: {
impls.push_back(std::make_unique<common_speculative_impl_draft_dflash>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
@@ -2034,7 +2340,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
}
if (impls.empty()) {
LOG_WRN("%s: no implementations specified for speculative decoding\n", __func__);
SPC_TRC("%s", "no implementations specified for speculative decoding\n");
return nullptr;
}
@@ -2161,13 +2467,13 @@ void common_speculative_draft(common_speculative * spec) {
if (dp.n_max > 0) {
if (!result.empty() && (int) result.size() > dp.n_max) {
LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max);
SPC_DBG("truncating draft to %d tokens\n", dp.n_max);
result.resize(dp.n_max);
}
}
if (!result.empty()) {
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
SPC_DBG("called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n",
common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(),
impl.get()->n_call_draft, result.size());
@@ -2291,7 +2597,7 @@ void common_speculative_print_stats(const common_speculative * spec) {
str_stats = ", #mean acc len = " + oss.str() + ", #acc rate/pos = (" + tmp.str() + ")";
}
LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
SPC_TRC("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
common_speculative_type_to_str(impl->type).c_str(),
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
impl->n_gen_drafts,
+2
View File
@@ -50,6 +50,8 @@ TEXT_MODEL_MAP: dict[str, str] = {
"DeepseekV2ForCausalLM": "deepseek",
"DeepseekV3ForCausalLM": "deepseek",
"DeepseekV32ForCausalLM": "deepseek",
"DFlashDraftModel": "qwen",
"DeepseekV4ForCausalLM": "deepseek",
"DistilBertForMaskedLM": "bert",
"DistilBertForSequenceClassification": "bert",
"DistilBertModel": "bert",
+14 -1
View File
@@ -1273,7 +1273,7 @@ class TextModel(ModelBase):
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None:
if (n_experts := self.find_hparam(["num_local_experts", "num_experts", "n_routed_experts"], optional=True)) is not None:
self.gguf_writer.add_expert_count(n_experts)
logger.info(f"gguf: expert count = {n_experts}")
if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token", "top_k_experts"], optional=True)) is not None:
@@ -1291,6 +1291,8 @@ class TextModel(ModelBase):
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif score_func == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
elif score_func == "sqrtsoftplus":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SQRTSOFTPLUS)
else:
raise ValueError(f"Unsupported expert score gating function value: {score_func}")
logger.info(f"gguf: expert score gating function = {score_func}")
@@ -2600,6 +2602,17 @@ class LazyTorchTensor(gguf.LazyBase):
return cls._wrap_fn(func)(*args, **kwargs)
if hasattr(torch, "float8_e8m0fnu"):
_torch_float8_e8m0 = torch.float8_e8m0fnu
LazyTorchTensor._dtype_map[_torch_float8_e8m0] = np.uint8
LazyTorchTensor._dtype_byteswap_map[_torch_float8_e8m0] = np.uint8
LazyTorchTensor._dtype_str_map["F8_E8M0"] = _torch_float8_e8m0
else:
# Older torch builds do not expose F8_E8M0. Keep the raw bytes so callers
# that know the format can decode them explicitly.
LazyTorchTensor._dtype_str_map["F8_E8M0"] = torch.uint8
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
# TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
# maybe we should fallback to text model's arch in that case, since not many models have both
+308 -1
View File
@@ -1,15 +1,18 @@
from __future__ import annotations
import json
import re
from pathlib import Path
from typing import Any, Callable, Iterable, TYPE_CHECKING
import numpy as np
import torch
if TYPE_CHECKING:
from torch import Tensor
from .base import MmprojModel, ModelBase, TextModel, gguf, logger
from .base import LazyTorchTensor, MmprojModel, ModelBase, TextModel, gguf, logger
from .qwen import QwenModel
@@ -467,3 +470,307 @@ class DeepseekV32Model(DeepseekV2Model):
self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
@ModelBase.register("DeepseekV4ForCausalLM")
class DeepseekV4Model(TextModel):
model_arch = gguf.MODEL_ARCH.DEEPSEEK4
_skipped_mtp_tensors = 0
def __init__(self, *args, **kwargs):
type(self)._skipped_mtp_tensors = 0
super().__init__(*args, **kwargs)
with open(self.dir_model / "config.json", "r", encoding="utf-8") as f:
raw_hparams = json.load(f)
for key, value in raw_hparams.items():
self.hparams.setdefault(key, value)
self.block_count = self.hparams["num_hidden_layers"]
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
self._dsv4_fp8_dequantized: set[str] = set()
self._dsv4_bf16_tensors: set[str] = set()
self._dsv4_f32_tensors: set[str] = set()
self._dsv4_mxfp4_generated = False
self._collect_source_dtypes()
if type(self)._skipped_mtp_tensors:
logger.info("Skipping %d DeepSeek-V4 MTP tensor(s) for conversion v0", type(self)._skipped_mtp_tensors)
# add a default chat template; if the model has a built-in template, it will be overridden later
template_path = Path(__file__).parent.parent / "models" / "templates" / "deepseek-ai-DeepSeek-V4.jinja"
if template_path.is_file():
with open(template_path, "r", encoding="utf-8") as f:
self.gguf_writer.add_chat_template(f.read())
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, _ = item
if name.startswith("mtp."):
cls._skipped_mtp_tensors += 1
return None
return super().filter_tensors(item)
@staticmethod
def _float8_dtypes() -> tuple[torch.dtype, ...]:
return tuple(
dtype for dtype in (
getattr(torch, "float8_e4m3fn", None),
getattr(torch, "float8_e5m2", None),
) if dtype is not None
)
@staticmethod
def _e8m0_to_float(scale: Tensor) -> Tensor:
torch_float8_e8m0 = getattr(torch, "float8_e8m0fnu", None)
if torch_float8_e8m0 is not None and scale.dtype == torch_float8_e8m0:
return scale.float()
bits = scale.view(torch.uint8).float()
return torch.exp2(bits - 127.0)
def _collect_source_dtypes(self) -> None:
for name, gen in self.model_tensors.items():
dtype = gen().dtype
if dtype == torch.bfloat16:
self._dsv4_bf16_tensors.add(name)
elif dtype == torch.float32:
self._dsv4_f32_tensors.add(name)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
self.gguf_writer.add_swiglu_clamp_exp([hparams["swiglu_limit"]] * self.block_count)
self.gguf_writer.add_swiglu_clamp_shexp([hparams["swiglu_limit"]] * self.block_count)
self.gguf_writer.add_indexer_head_count(hparams["index_n_heads"])
self.gguf_writer.add_indexer_key_length(hparams["index_head_dim"])
self.gguf_writer.add_indexer_top_k(hparams["index_topk"])
self.gguf_writer.add_attention_output_group_count(hparams["o_groups"])
self.gguf_writer.add_attention_output_lora_rank(hparams["o_lora_rank"])
self.gguf_writer.add_attention_compress_ratios(hparams["compress_ratios"])
self.gguf_writer.add_attention_compress_rope_freq_base(hparams["compress_rope_theta"])
self.gguf_writer.add_hyper_connection_count(hparams["hc_mult"])
self.gguf_writer.add_hyper_connection_sinkhorn_iterations(hparams["hc_sinkhorn_iters"])
self.gguf_writer.add_hyper_connection_epsilon(hparams["hc_eps"])
self.gguf_writer.add_hash_layer_count(hparams["num_hash_layers"])
def dequant_model(self):
fp8_dtypes = self._float8_dtypes()
tensors_to_remove: list[str] = []
def dequant_fp8_weight(weight: Tensor, scale: Tensor) -> Tensor:
out_features, in_features = weight.shape
scale_f = self._e8m0_to_float(scale)
scale_f = scale_f.repeat_interleave(128, 0)[:out_features]
scale_f = scale_f.repeat_interleave(128, 1)[:, :in_features]
return weight.float() * scale_f
for name in list(self.model_tensors.keys()):
if not name.endswith(".scale"):
continue
weight_name = name.removesuffix(".scale") + ".weight"
if weight_name not in self.model_tensors:
continue
weight = self.model_tensors[weight_name]
scale = self.model_tensors[name]
if weight().dtype not in fp8_dtypes:
continue
self.model_tensors[weight_name] = lambda w=weight, s=scale: dequant_fp8_weight(w(), s())
self._dsv4_fp8_dequantized.add(weight_name)
tensors_to_remove.append(name)
for name in tensors_to_remove:
del self.model_tensors[name]
@staticmethod
def _pack_mxfp4_blocks(weight: Tensor, scale: Tensor) -> np.ndarray:
packed = weight.contiguous().view(torch.uint8)
scale_u8 = scale.contiguous().view(torch.uint8)
out_features, packed_cols = packed.shape
logical_cols = packed_cols * 2
if logical_cols % 32 != 0:
raise ValueError(f"MXFP4 source row has {logical_cols} values, expected a multiple of 32")
n_blocks = logical_cols // 32
if tuple(scale_u8.shape) != (out_features, n_blocks):
raise ValueError(f"MXFP4 scale shape {tuple(scale_u8.shape)} does not match {(out_features, n_blocks)}")
src = packed.reshape(out_features, n_blocks, 16)
low = src & 0x0F
high = (src >> 4) & 0x0F
# The safetensors bytes store adjacent values as low/high nibbles.
# ggml MXFP4 blocks store values 0..15 in low nibbles and 16..31 in high nibbles.
vals = torch.stack((low, high), dim=-1).reshape(out_features, n_blocks, 32)
qs = vals[:, :, :16] | (vals[:, :, 16:] << 4)
raw = torch.cat((scale_u8.unsqueeze(-1), qs.to(torch.uint8)), dim=-1)
return raw.reshape(out_features, n_blocks * 17).cpu().numpy()
def _write_mxfp4_expert_tensor(self, bid: int, proj: str, tensor_key: gguf.MODEL_TENSOR) -> list[str]:
n_experts = self.hparams["n_routed_experts"]
data: np.ndarray | None = None
consumed: list[str] = []
for eid in range(n_experts):
weight_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.weight"
scale_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.scale"
if weight_name not in self.model_tensors or scale_name not in self.model_tensors:
raise KeyError(f"Missing routed expert tensors for {weight_name}")
weight = LazyTorchTensor.to_eager(self.model_tensors[weight_name]())
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
packed = self._pack_mxfp4_blocks(weight, scale)
if data is None:
data = np.empty((n_experts, *packed.shape), dtype=packed.dtype)
data[eid] = packed
consumed.extend((weight_name, scale_name))
assert data is not None
new_name = self.format_tensor_name(tensor_key, bid)
shape = gguf.quant_shape_from_byte_shape(data.shape, gguf.GGMLQuantizationType.MXFP4)
logger.info(f"{new_name}: repacked routed experts to MXFP4, shape = {{{', '.join(str(n) for n in reversed(shape))}}}")
self.gguf_writer.add_tensor(new_name, data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
return consumed
def _write_hash_routing_tensors(self) -> list[str]:
consumed: list[str] = []
for bid in range(self.hparams["num_hash_layers"]):
name = f"layers.{bid}.ffn.gate.tid2eid"
if name not in self.model_tensors:
raise KeyError(f"Missing hash routing tensor {name}")
data_torch = LazyTorchTensor.to_eager(self.model_tensors[name]())
data = data_torch.to(torch.int32).cpu().numpy()
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_TID2EID, bid, ".weight")
logger.info(f"{new_name}: converted hash routing table to I32, shape = {{{', '.join(str(n) for n in reversed(data.shape))}}}")
self.gguf_writer.add_tensor(new_name, data)
consumed.append(name)
return consumed
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
if self._dsv4_mxfp4_generated:
return ()
consumed: list[str] = self._write_hash_routing_tensors()
for bid in range(self.block_count):
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w1", gguf.MODEL_TENSOR.FFN_GATE_EXP))
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP))
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w3", gguf.MODEL_TENSOR.FFN_UP_EXP))
for name in consumed:
del self.model_tensors[name]
self._dsv4_mxfp4_generated = True
return ()
def _format_dsv4_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> str:
return self.format_tensor_name(key, bid, suffix)
def _map_dsv4_tensor_name(self, name: str, bid: int | None) -> tuple[gguf.MODEL_TENSOR, str]:
root_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
"embed.weight": (gguf.MODEL_TENSOR.TOKEN_EMBD, ".weight"),
"norm.weight": (gguf.MODEL_TENSOR.OUTPUT_NORM, ".weight"),
"head.weight": (gguf.MODEL_TENSOR.OUTPUT, ".weight"),
"hc_head_fn": (gguf.MODEL_TENSOR.HC_HEAD_FN, ".weight"),
"hc_head_base": (gguf.MODEL_TENSOR.HC_HEAD_BASE, ".weight"),
"hc_head_scale": (gguf.MODEL_TENSOR.HC_HEAD_SCALE, ".weight"),
}
if name in root_map:
return root_map[name]
match = re.match(r"layers\.(\d+)\.(.+)$", name)
if match is None:
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
layer = int(match.group(1))
if bid != layer:
raise ValueError(f"Tensor {name!r} parsed bid {bid} but layer name has {layer}")
layer_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
"hc_attn_fn": (gguf.MODEL_TENSOR.HC_ATTN_FN, ".weight"),
"hc_attn_base": (gguf.MODEL_TENSOR.HC_ATTN_BASE, ".weight"),
"hc_attn_scale": (gguf.MODEL_TENSOR.HC_ATTN_SCALE, ".weight"),
"hc_ffn_fn": (gguf.MODEL_TENSOR.HC_FFN_FN, ".weight"),
"hc_ffn_base": (gguf.MODEL_TENSOR.HC_FFN_BASE, ".weight"),
"hc_ffn_scale": (gguf.MODEL_TENSOR.HC_FFN_SCALE, ".weight"),
"attn.attn_sink": (gguf.MODEL_TENSOR.ATTN_SINKS, ".weight"),
"attn.wq_a.weight": (gguf.MODEL_TENSOR.ATTN_Q_A, ".weight"),
"attn.wq_b.weight": (gguf.MODEL_TENSOR.ATTN_Q_B, ".weight"),
"attn.q_norm.weight": (gguf.MODEL_TENSOR.ATTN_Q_A_NORM, ".weight"),
"attn.wkv.weight": (gguf.MODEL_TENSOR.ATTN_KV, ".weight"),
"attn.kv_norm.weight": (gguf.MODEL_TENSOR.ATTN_KV_NORM, ".weight"),
"attn.wo_a.weight": (gguf.MODEL_TENSOR.ATTN_OUT_A, ".weight"),
"attn.wo_b.weight": (gguf.MODEL_TENSOR.ATTN_OUT_B, ".weight"),
"attn.compressor.ape": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_APE, ".weight"),
"attn.compressor.wkv.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WKV, ".weight"),
"attn.compressor.wgate.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WGATE, ".weight"),
"attn.compressor.norm.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_NORM, ".weight"),
"attn.indexer.wq_b.weight": (gguf.MODEL_TENSOR.INDEXER_ATTN_Q_B, ".weight"),
"attn.indexer.weights_proj.weight": (gguf.MODEL_TENSOR.INDEXER_PROJ, ".weight"),
"attn.indexer.compressor.ape": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_APE, ".weight"),
"attn.indexer.compressor.wkv.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WKV, ".weight"),
"attn.indexer.compressor.wgate.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE, ".weight"),
"attn.indexer.compressor.norm.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_NORM, ".weight"),
"attn_norm.weight": (gguf.MODEL_TENSOR.ATTN_NORM, ".weight"),
"ffn_norm.weight": (gguf.MODEL_TENSOR.FFN_NORM, ".weight"),
"ffn.gate.weight": (gguf.MODEL_TENSOR.FFN_GATE_INP, ".weight"),
"ffn.gate.bias": (gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, ".bias"),
"ffn.gate.tid2eid": (gguf.MODEL_TENSOR.FFN_GATE_TID2EID, ".weight"),
"ffn.shared_experts.w1.weight": (gguf.MODEL_TENSOR.FFN_GATE_SHEXP, ".weight"),
"ffn.shared_experts.w2.weight": (gguf.MODEL_TENSOR.FFN_DOWN_SHEXP, ".weight"),
"ffn.shared_experts.w3.weight": (gguf.MODEL_TENSOR.FFN_UP_SHEXP, ".weight"),
}
tensor_name = match.group(2)
if tensor_name in layer_map:
return layer_map[tensor_name]
if re.match(r"ffn\.experts\.\d+\.w[123]\.(weight|scale)$", tensor_name):
return gguf.MODEL_TENSOR.FFN_GATE_EXP, ".weight"
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if re.match(r"layers\.\d+\.ffn\.experts\.\d+\.w[123]\.(weight|scale)$", name):
return []
tensor_key, suffix = self._map_dsv4_tensor_name(name, bid)
if tensor_key == gguf.MODEL_TENSOR.FFN_GATE_TID2EID:
return []
return [(self._format_dsv4_tensor_name(tensor_key, bid, suffix), data_torch)]
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del new_name, bid # unused
if name in self._dsv4_fp8_dequantized and n_dims >= 2:
return gguf.GGMLQuantizationType.Q8_0
if name in self._dsv4_f32_tensors:
return gguf.GGMLQuantizationType.F32
if name in self._dsv4_bf16_tensors and n_dims >= 2:
return gguf.GGMLQuantizationType.BF16
return False
def prepare_tensors(self):
super().prepare_tensors()
self._is_mxfp4 = True
self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
+3 -3
View File
@@ -73,7 +73,7 @@ class LlamaModel(TextModel):
target_num_layers = target_config["num_hidden_layers"]
target_layers = [2, target_num_layers // 2, target_num_layers - 3]
logger.info(f"EAGLE-3: target_layers = {target_layers} (target model has {target_num_layers} layers)")
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", target_layers)
self.gguf_writer.add_target_layers(target_layers)
# target_hidden_size: prefer eagle3 config, fallback to target config
if eagle3_raw_config.get("target_hidden_size") is not None:
@@ -83,12 +83,12 @@ class LlamaModel(TextModel):
target_hidden_size = target_config["hidden_size"]
src = "target model config"
logger.info(f"EAGLE-3: target_hidden_size = {target_hidden_size} (from {src})")
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size)
self.gguf_writer.add_target_hidden_size(target_hidden_size)
# norm_before_residual (RedHat-style eagle3 specific)
norm_before_residual = eagle3_raw_config.get("norm_before_residual", False)
logger.info(f"EAGLE-3: norm_before_residual = {norm_before_residual}")
self.gguf_writer.add_bool(f"{self.gguf_writer.arch}.norm_before_residual", norm_before_residual)
self.gguf_writer.add_norm_before_residual(norm_before_residual)
def set_vocab(self):
# eagle3: use tokenizer from target model if provided
+3 -4
View File
@@ -114,7 +114,8 @@ class Mamba2Model(TextModel):
hparams["text_config"] = hparams["llm_config"]
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
self.expand = self.find_hparam(["mamba_expand", "expand"], optional=True) or 2
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or self.expand * self.d_model
self.n_group = self.find_hparam(["n_groups"], optional=True) or 1
def set_vocab(self):
@@ -144,11 +145,9 @@ class Mamba2Model(TextModel):
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
# skip the assertion for FalconH1 Model
if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
assert self.d_inner == 2 * self.d_model
assert self.d_inner == self.expand * self.d_model
assert self.d_inner % head_dim == 0
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
+48
View File
@@ -625,3 +625,51 @@ class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReor
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35MOE
@ModelBase.register("DFlashDraftModel")
class DFlashModel(Qwen3Model):
model_arch = gguf.MODEL_ARCH.DFLASH
def set_vocab(self):
if self.target_model_dir is None:
raise ValueError(
"DFlash draft model requires --target-model-dir to be specified. "
"Please provide the path to the target model directory containing the tokenizer."
)
logger.info(f"DFlash: Using tokenizer from target model: {self.target_model_dir}")
original_dir = self.dir_model
self.dir_model = self.target_model_dir
super().set_vocab()
self.dir_model = original_dir
mask_token_id = self.hparams.get("dflash_config", {}).get("mask_token_id")
if mask_token_id is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)
def set_gguf_parameters(self):
super().set_gguf_parameters()
block_size = self.hparams.get("block_size", 16)
self.gguf_writer.add_block_size(block_size)
dflash_config = self.hparams.get("dflash_config", {})
target_layer_ids = dflash_config.get("target_layer_ids", [])
if target_layer_ids:
extract_layer_ids = [i + 1 for i in target_layer_ids]
self.gguf_writer.add_target_layers(extract_layer_ids)
use_sliding_window = self.hparams.get("use_sliding_window", False)
sliding_window = self.hparams.get("sliding_window")
layer_types = self.hparams.get("layer_types")
if use_sliding_window and sliding_window and layer_types:
is_swa = [lt == "sliding_attention" for lt in layer_types]
self.gguf_writer.add_sliding_window(sliding_window)
self.gguf_writer.add_sliding_window_pattern(is_swa)
@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
name, gen = item
if not name.startswith("model."):
name = "model." + name
return super().filter_tensors((name, gen))
+6 -6
View File
@@ -237,8 +237,8 @@ chmod +x ubuntu-llamacpp-ov-install.sh
# ============================================
set -euo pipefail
OPENVINO_VERSION_MAJOR="2026.2"
OPENVINO_VERSION_FULL="2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR="2026.2.1"
OPENVINO_VERSION_FULL="2026.2.1.21919.ede283a88e3"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OPENVINO_INSTALL_DIR="/opt/intel/openvino_${OPENVINO_VERSION_MAJOR}"
@@ -334,7 +334,7 @@ echo " ./build/ReleaseOV/bin/llama-cli -m model.gguf"
```
> [!NOTE]
> The script pins OpenVINO `2026.2` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release.
> The script pins OpenVINO `2026.2.1` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release.
</details>
@@ -364,8 +364,8 @@ REM ============================================
REM llama.cpp OpenVINO Build Script (Ninja)
REM ============================================
set "OPENVINO_VERSION_MAJOR=2026.2"
set "OPENVINO_VERSION_FULL=2026.2.0.21903.52ddc073857"
set "OPENVINO_VERSION_MAJOR=2026.2.1"
set "OPENVINO_VERSION_FULL=2026.2.1.21919.ede283a88e3"
set "SCRIPT_DIR=%~dp0"
set "VCPKG_DIR=C:\vcpkg"
@@ -547,7 +547,7 @@ endlocal
```
> [!NOTE]
> The script pins OpenVINO `2026.2` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release. From any new shell, source the matching `setupvars` script via the junction — `call "C:\Intel\openvino\setupvars.bat"` from `cmd`, or `& "C:\Intel\openvino\setupvars.ps1"` from PowerShell. If `winget` cannot register Visual Studio Build Tools on first run, install them once manually and re-run the script from an elevated **Developer Command Prompt for VS 2022**.
> The script pins OpenVINO `2026.2.1` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release. From any new shell, source the matching `setupvars` script via the junction — `call "C:\Intel\openvino\setupvars.bat"` from `cmd`, or `& "C:\Intel\openvino\setupvars.ps1"` from PowerShell. If `winget` cannot register Visual Studio Build Tools on first run, install them once manually and re-run the script from an elevated **Developer Command Prompt for VS 2022**.
</details>
+28 -1
View File
@@ -52,6 +52,32 @@ Supported EAGLE-3 draft models include:
For the full and up-to-date list of supported models, see #18039.
### DFlash (`draft-dflash`)
DFlash produces an entire block of draft tokens in a single forward pass (block diffusion) and
injects the target model's hidden states into the draft model's attention, instead of drafting one
token at a time. This keeps the draft model small while making drafting GPU-friendly. Unlike EAGLE-3
(a single-layer autoregressive draft), the DFlash draft uses several transformer layers but emits a
whole block per draft step.
The draft is a small block-diffusion model trained for a specific target (for example
`z-lab/Qwen3-4B-DFlash` for `Qwen/Qwen3-4B`). Convert it with `--target-model-dir` so it inherits the
target's tokenizer and token embeddings:
```bash
python convert_hf_to_gguf.py z-lab/Qwen3-4B-DFlash \
--target-model-dir Qwen/Qwen3-4B --outtype bf16 --outfile Qwen3-4B-DFlash.gguf
llama-server -m Qwen3-4B.gguf -md Qwen3-4B-DFlash.gguf \
--spec-type draft-dflash --spec-draft-n-max 15 -fa on --jinja
```
`--spec-draft-n-max` is clamped to the draft model's trained block size.
See:
- #22105
### n-gram Cache (`ngram-cache`)
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
@@ -147,7 +173,7 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
### General Speculative Parameters
```
--spec-type [none|draft-simple|draft-eagle3|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
--spec-type [none|draft-simple|draft-eagle3|draft-dflash|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
comma-separated list of types of speculative decoding to use
(default: none)
(env: LLAMA_ARG_SPEC_TYPE)
@@ -287,6 +313,7 @@ Specifies a comma-separated list of speculative decoding types to use.
| `none` | No speculative decoding (default) |
| `draft-simple` | Use a simple draft model for speculation |
| `draft-eagle3` | Use an EAGLE-3 draft model that reads the target's hidden states |
| `draft-dflash` | Use a DFlash block-diffusion draft model that emits a block per step |
| `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model |
| `ngram-cache` | Use n-gram cache lookup |
| `ngram-simple` | Use simple n-gram pattern matching |
+1 -1
View File
@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 15)
set(GGML_VERSION_PATCH 2)
set(GGML_VERSION_PATCH 3)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
+7 -3
View File
@@ -1551,6 +1551,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
int split_backend_id = split->backend_id;
ggml_backend_t split_backend = sched->backends[split_backend_id];
ggml_backend_synchronize(split_backend);
// copy the input tensors to the split backend
for (int input_id = 0; input_id < split->n_inputs; input_id++) {
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
@@ -1561,15 +1563,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
// inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
} else {
} else if (!split_backend->iface.cpy_tensor_async) {
ggml_backend_synchronize(split_backend);
}
ggml_backend_tensor_copy(input, input_cpy);
ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
} else {
// wait for the split backend to finish using the input before overwriting it
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
} else {
} else if (!split_backend->iface.cpy_tensor_async) {
ggml_backend_synchronize(split_backend);
}
@@ -1674,6 +1676,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
}
}
ggml_backend_synchronize(split_backend);
if (!sched->callback_eval) {
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
if (ec != GGML_STATUS_SUCCESS) {
+2 -2
View File
@@ -75,12 +75,12 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
ay1 = GGML_F32_VEC_LOAD(y + i);
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
}
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmla on available elements only
if (np2 < n) {
svbool_t pg = svwhilelt_b32(np2, n);
ax1 = svld1_f32(pg, x + np2);
ay1 = svld1_f32(pg, y + np2);
sum1 = svmad_f32_m(pg, ax1, ay1, sum1);
sum1 = svmla_f32_m(pg, sum1, ax1, ay1);
}
// reduce sum1,sum2 to sum1
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
+45
View File
@@ -386,6 +386,46 @@ static void ggml_cpy_f32_iq4_nl_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
// check if a same-type copy reduces to a 2D strided copy (height rows of width
// contiguous bytes), so it can use cudaMemcpy2DAsync instead of the scalar kernel
static bool ggml_cuda_cpy_as_memcpy_2d(const ggml_tensor * src0, const ggml_tensor * src1,
size_t & width, size_t & height, size_t & spitch, size_t & dpitch) {
// require matching shape: a reshaped copy maps elements by flat order, which the
// prefix walk below does not handle
if (src0->type != src1->type || !ggml_are_same_shape(src0, src1)) {
return false;
}
// grow the contiguous prefix block shared by both tensors
size_t block_nb = ggml_element_size(src0);
int d = 0;
for (; d < GGML_MAX_DIMS; ++d) {
if (src0->nb[d] != block_nb || src1->nb[d] != block_nb) {
break;
}
block_nb *= src0->ne[d];
}
// d == 0: nothing contiguous; d == GGML_MAX_DIMS: fully contiguous (handled by memcpy)
if (d == 0 || d == GGML_MAX_DIMS) {
return false;
}
// dim d carries the rows; everything above it must be a single element
for (int i = d + 1; i < GGML_MAX_DIMS; ++i) {
if (src0->ne[i] != 1) {
return false;
}
}
width = block_nb;
height = src0->ne[d];
spitch = src0->nb[d];
dpitch = src1->nb[d];
return spitch >= width && dpitch >= width;
}
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
@@ -421,6 +461,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
size_t mc_width = 0, mc_height = 0, mc_spitch = 0, mc_dpitch = 0;
if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@@ -431,6 +473,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
{
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (ggml_cuda_cpy_as_memcpy_2d(src0, src1, mc_width, mc_height, mc_spitch, mc_dpitch)) {
CUDA_CHECK(cudaMemcpy2DAsync(src1_ddc, mc_dpitch, src0_ddc, mc_spitch,
mc_width, mc_height, cudaMemcpyDeviceToDevice, main_stream));
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
if (can_be_transposed) {
ggml_cpy_scalar_cuda<float, float, true>
+20 -4
View File
@@ -3192,11 +3192,24 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
// Enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA
// Excluding this path for HIP and MUSA as a precaution.
// According to the summary in https://github.com/ggml-org/llama.cpp/pull/20793#issuecomment-4275794315, this change is not beneficial for hip anyways.
// Additionally, there is a lot of anectodal evidence that hip/musa stream behavior might not always 1:1 match CUDA behavior.
// e.g. https://github.com/ROCm/rocm-systems/issues/5109
// It thus makes sense to exclude this path for HIP and MUSA. This PR was not aimed these backends, the majority of testing happened on CUDA.
// This can be revisited in the future if enabling copy_from_host benefits hip/MUSA, and if the PR author can extensively test on these backends.
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA)
const bool copy_from_host = false;
#else
const bool copy_from_host = ggml_backend_buffer_is_host(buf_src) && ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU;
#endif
if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) {
return false;
}
if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) {
if (!(copy_from_host || ggml_backend_buffer_is_cuda(buf_src)) || !ggml_backend_buffer_is_cuda(buf_dst)) {
return false;
}
@@ -3207,14 +3220,17 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context;
ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context;
if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
if ((copy_from_host && cuda_ctx_dst->device != buf_ctx_dst->device) ||
!copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
#endif // NDEBUG
return false;
}
if (backend_src != backend_dst) {
if (copy_from_host) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream()));
} else if (backend_src != backend_dst) {
// copy on src stream
if (cuda_ctx_src->device == cuda_ctx_dst->device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
+55 -12
View File
@@ -2,6 +2,28 @@
#include <cstdint>
static __global__ void k_compute_out_prod_ptrs(
const float * src0_d, const float * src1_d, float * dst_d,
const float ** ptrs_a, const float ** ptrs_b, float ** ptrs_c,
const int64_t ne2, const int64_t ne3,
const int64_t dps2, const int64_t dps3,
const size_t s02, const size_t s03,
const size_t s12, const size_t s13,
const size_t s2, const size_t s3) {
const int64_t i2 = blockIdx.x*blockDim.x + threadIdx.x;
const int64_t i3 = blockIdx.y*blockDim.y + threadIdx.y;
if (i2 >= ne2 || i3 >= ne3) {
return;
}
const int64_t idx = i3*ne2 + i2;
ptrs_a[idx] = src0_d + (i3/dps3)*s03 + (i2/dps2)*s02;
ptrs_b[idx] = src1_d + i3 *s13 + i2 *s12;
ptrs_c[idx] = dst_d + i3 *s3 + i2 *s2;
}
void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -67,18 +89,39 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
&beta, dst_d + i3 *s3, ldc, s2,
batch_count));
}
} else if (ne2 > 1 || ne3 > 1) {
// dps2 > 1 (src0 broadcast along dim 2 with non-uniform stride) or multiple GEMMs
// along dim 3: compute per-GEMM pointers on the device and use a single batched GEMM.
GGML_ASSERT(ne3 > 0);
GGML_ASSERT(ne2 <= (int64_t) std::numeric_limits<int>::max() / ne3);
const int batch_count = (int) (ne2 * ne3);
ggml_cuda_pool_alloc<const float *> ptrs_a(ctx.pool(), batch_count);
ggml_cuda_pool_alloc<const float *> ptrs_b(ctx.pool(), batch_count);
ggml_cuda_pool_alloc< float *> ptrs_c(ctx.pool(), batch_count);
const dim3 block_dims(16, 16);
const dim3 grid_dims((ne2 + block_dims.x - 1)/block_dims.x, (ne3 + block_dims.y - 1)/block_dims.y);
k_compute_out_prod_ptrs<<<grid_dims, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d,
ptrs_a.get(), ptrs_b.get(), ptrs_c.get(),
ne2, ne3, dps2, dps3, s02, s03, s12, s13, s2, s3);
CUDA_CHECK(cudaGetLastError());
CUBLAS_CHECK(
cublasSgemmBatched(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, ptrs_a.get(), lda,
ptrs_b.get(), ldb,
&beta, ptrs_c.get(), ldc,
batch_count));
} else {
// Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
// with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
src1_d + i3 *s13 + i2 *s12, ldb,
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
}
}
// ne2 == 1 && ne3 == 1: single GEMM
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d, lda,
src1_d, ldb,
&beta, dst_d, ldc));
}
}
+1
View File
@@ -48,6 +48,7 @@
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasSgemmBatched hipblasSgemmBatched
#define cublasSgemmStridedBatched hipblasSgemmStridedBatched
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
+1
View File
@@ -32,6 +32,7 @@
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasSgemmBatched mublasSgemmBatched
#define cublasSgemmStridedBatched mublasSgemmStridedBatched
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
+3
View File
@@ -192,7 +192,10 @@ set(GGML_OPENCL_KERNELS
mul_mm_f16_f32_kq_kqv
conv2d
conv2d_f16_f32
flash_attn_pre_f16
flash_attn_f32_f16
flash_attn_f32_q8_0
flash_attn_f32_q4_0
flash_attn_f16
flash_attn_f32
)
+91
View File
@@ -0,0 +1,91 @@
#pragma once
// Flash-attention per-(dk,dv) tile tuning for the Adreno OpenCL backend.
// Isolated from ggml-opencl.cpp so the tuning numbers are easy to find and
// edit; the FA dispatch and kernel-compile logic stay in the main file.
// This header is a file section — it is #included exactly once, at the point
// in ggml-opencl.cpp where the ggml logging macros are already in scope.
// Per-(dk, dv) FA config; shared by dispatch and supports_op.
struct ggml_opencl_fa_dim {
int dk; int dv; int bm; int bn; int n_split; int nkv_split_threshold;
};
// Split variant fires when n_kv >= threshold (threshold=0 -> always split).
// Default tuning covers Adreno 7xx/8xx mobile and X1-series laptop GPUs.
static const ggml_opencl_fa_dim g_fa_dims_adreno_default[] = {
{ 40, 40, 64, 32, 1, 0}, { 64, 64, 64, 32, 2, 64},
{ 80, 80, 64, 32, 2, 64}, { 96, 96, 64, 32, 2, 64},
{112, 112, 64, 32, 2, 64}, {128, 128, 64, 32, 2, 64},
{192, 128, 16, 16, 1, 0},
{192, 192, 16, 16, 1, 0},
{256, 256, 16, 16, 16, 0},
};
struct ggml_opencl_fa_dim_table {
const ggml_opencl_fa_dim * data;
size_t count;
const ggml_opencl_fa_dim * begin() const { return data; }
const ggml_opencl_fa_dim * end() const { return data + count; }
};
// Mutable copy of the active table; GGML_OPENCL_FA_TUNE patches entries here
// at backend init without touching the const source table.
static ggml_opencl_fa_dim g_fa_dims_runtime[
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0])];
static ggml_opencl_fa_dim_table g_opencl_fa_dims = {
g_fa_dims_adreno_default,
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]),
};
// GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,…] — patches matching entries
// in the active table at backend init, before the first FA kernel compiles.
// Unmatched (dk,dv) pairs are warned and ignored.
static void ggml_opencl_fa_apply_env_overrides() {
const char * e = std::getenv("GGML_OPENCL_FA_TUNE");
if (!e || !e[0]) {
return;
}
std::string s = e;
size_t pos = 0;
while (pos < s.size()) {
size_t comma = s.find(',', pos);
std::string entry = s.substr(pos, comma == std::string::npos ? std::string::npos : comma - pos);
int dk, dv, bm, bn, nsplit, thr;
if (std::sscanf(entry.c_str(), "%d:%d:%d:%d:%d:%d", &dk, &dv, &bm, &bn, &nsplit, &thr) == 6) {
bool patched = false;
for (size_t i = 0; i < g_opencl_fa_dims.count; ++i) {
ggml_opencl_fa_dim & d = g_fa_dims_runtime[i];
if (d.dk == dk && d.dv == dv) {
d.bm = bm; d.bn = bn; d.n_split = nsplit; d.nkv_split_threshold = thr;
GGML_LOG_INFO("ggml_opencl: FA tune override DK=%d DV=%d -> bm=%d bn=%d n_split=%d thr=%d\n",
dk, dv, bm, bn, nsplit, thr);
patched = true;
break;
}
}
if (!patched) {
GGML_LOG_WARN("ggml_opencl: FA tune override DK=%d DV=%d ignored (no matching dim)\n", dk, dv);
}
} else {
GGML_LOG_WARN("ggml_opencl: FA tune override entry malformed: '%s'\n", entry.c_str());
}
if (comma == std::string::npos) break;
pos = comma + 1;
}
}
// Copy the default table into the mutable runtime buffer and apply any
// GGML_OPENCL_FA_TUNE overrides. A per-generation table can be added here
// once it has been tuned on hardware.
static void ggml_cl_init_fa_dims_table() {
const size_t count = sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]);
for (size_t i = 0; i < count; ++i) {
g_fa_dims_runtime[i] = g_fa_dims_adreno_default[i];
}
g_opencl_fa_dims = { g_fa_dims_runtime, count };
ggml_opencl_fa_apply_env_overrides();
}
File diff suppressed because it is too large Load Diff
+152
View File
@@ -1582,6 +1582,158 @@ kernel void kernel_restore_block_q8_0(
}
}
// View-aware AoS q8_0 -> f32 dequant (f32/f32 FA path).
kernel void kernel_dequant_q8_0_f32_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global float * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
float d = vload_half(0, (global half *)block);
global char * qs = block + 2;
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global float * out = dst + (dst_row_base + blk_i0) * QK8_0;
for (int i = 0; i < QK8_0; ++i) {
out[i] = d * (float)qs[i];
}
}
// View-aware AoS q8_0 -> f16 dequant. Rows tight, batch strides may be gapped.
kernel void kernel_dequant_q8_0_f16_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global half * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
float d = vload_half(0, (global half *)block);
global char * qs = block + 2;
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global half * out = dst + (dst_row_base + blk_i0) * QK8_0;
for (int i = 0; i < QK8_0; ++i) {
out[i] = (half)(d * (float)qs[i]);
}
}
// View-aware AoS q4_0 -> f32 dequant (mirrors the q8_0 view variant).
kernel void kernel_dequant_q4_0_f32_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global float * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
float d = vload_half(0, (global half *)block);
global uchar * qs = (global uchar *)(block + 2);
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global float * out = dst + (dst_row_base + blk_i0) * QK4_0;
for (int i = 0; i < QK4_0/2; ++i) {
uchar byte = qs[i];
int q0 = (int)(byte & 0x0F) - 8;
int q1 = (int)(byte >> 4) - 8;
out[i] = d * (float)q0;
out[i + QK4_0/2] = d * (float)q1;
}
}
// View-aware AoS q4_0 -> f16 dequant (mirrors the q8_0 view variant).
kernel void kernel_dequant_q4_0_f16_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global half * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
float d = vload_half(0, (global half *)block);
global uchar * qs = (global uchar *)(block + 2);
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global half * out = dst + (dst_row_base + blk_i0) * QK4_0;
for (int i = 0; i < QK4_0/2; ++i) {
uchar byte = qs[i];
int q0 = (int)(byte & 0x0F) - 8;
int q1 = (int)(byte >> 4) - 8;
out[i] = (half)(d * (float)q0);
out[i + QK4_0/2] = (half)(d * (float)q1);
}
}
kernel void kernel_restore_block_q8_0_trans(
global uchar * src_q,
global half * src_d,
+75 -40
View File
@@ -4,14 +4,26 @@
#define ACC_TYPE4 float4
#define DATA_TYPE half
#define DATA_TYPE4 half4
#define CONVERT_ACC4(x) convert_float4(x)
#define CONVERT_DATA4(x) convert_half4(x)
#define CONVERT_ACC4(x) ((float4)((float)(x).s0, (float)(x).s1, (float)(x).s2, (float)(x).s3))
#define CONVERT_DATA4(x) ((half4)((half)(x).s0, (half)(x).s1, (half)(x).s2, (half)(x).s3))
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -81,18 +93,18 @@ __kernel void flash_attn_f16(
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -125,49 +137,72 @@ __kernel void flash_attn_f16(
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
const ACC_TYPE p0 = native_exp(s0 - m_new);
const ACC_TYPE p1 = native_exp(s1 - m_new);
const ACC_TYPE p2 = native_exp(s2 - m_new);
const ACC_TYPE p3 = native_exp(s3 - m_new);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1;
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
@@ -179,7 +214,7 @@ __kernel void flash_attn_f16(
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -191,12 +226,12 @@ __kernel void flash_attn_f16(
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
@@ -258,7 +293,7 @@ __kernel void flash_attn_f16_q1(
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
@@ -270,12 +305,12 @@ __kernel void flash_attn_f16_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -293,7 +328,7 @@ __kernel void flash_attn_f16_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -301,7 +336,7 @@ __kernel void flash_attn_f16_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -311,7 +346,7 @@ __kernel void flash_attn_f16_q1(
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -325,7 +360,7 @@ __kernel void flash_attn_f16_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -335,7 +370,7 @@ __kernel void flash_attn_f16_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -354,7 +389,7 @@ __kernel void flash_attn_f16_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -364,7 +399,7 @@ __kernel void flash_attn_f16_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
+73 -38
View File
@@ -13,6 +13,18 @@
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -82,18 +94,18 @@ __kernel void flash_attn_f32(
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -126,49 +138,72 @@ __kernel void flash_attn_f32(
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
const ACC_TYPE p0 = native_exp(s0 - m_new);
const ACC_TYPE p1 = native_exp(s1 - m_new);
const ACC_TYPE p2 = native_exp(s2 - m_new);
const ACC_TYPE p3 = native_exp(s3 - m_new);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1;
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
@@ -180,7 +215,7 @@ __kernel void flash_attn_f32(
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -192,12 +227,12 @@ __kernel void flash_attn_f32(
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
@@ -259,7 +294,7 @@ __kernel void flash_attn_f32_q1(
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
@@ -271,12 +306,12 @@ __kernel void flash_attn_f32_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -294,7 +329,7 @@ __kernel void flash_attn_f32_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -302,7 +337,7 @@ __kernel void flash_attn_f32_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -312,7 +347,7 @@ __kernel void flash_attn_f32_q1(
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -326,7 +361,7 @@ __kernel void flash_attn_f32_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -336,7 +371,7 @@ __kernel void flash_attn_f32_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -355,7 +390,7 @@ __kernel void flash_attn_f32_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -365,7 +400,7 @@ __kernel void flash_attn_f32_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
@@ -1,5 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_khr_subgroup_shuffle
#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable
#define HAS_SUBGROUP_SHUFFLE 1
#elif defined(cl_qcom_subgroup_shuffle)
#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
#define HAS_SUBGROUP_SHUFFLE 1
#endif
#define ACC_TYPE float
#define ACC_TYPE4 float4
#define Q_DATA_TYPE4 float4
@@ -12,9 +20,34 @@
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
// N_SPLIT>1 splits DK/DV across threads to cut per-thread register use.
#ifndef N_SPLIT
#define N_SPLIT 1
#endif
#define SPLIT_DK_VEC (DK_VEC / N_SPLIT)
#define SPLIT_DV_VEC (DV_VEC / N_SPLIT)
#if N_SPLIT > 1
#define WG_SIZE (BLOCK_M * N_SPLIT)
#else
#define WG_SIZE (BLOCK_M)
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -54,19 +87,38 @@ __kernel void flash_attn_f32_f16(
const int mask_ne2,
const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
const ulong sinks_offset,
const global void * k_pad_void,
const global void * v_pad_void,
const global void * mask_pad_void,
const global char * blk,
const int n_kv_blocks,
const ulong mask_pad_nb1,
const ulong mask_pad_nb2,
const ulong mask_pad_nb3
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
const int head_batch_idx = get_global_id(1);
const int my_query_row = block_q_idx * BLOCK_M + tid;
#if N_SPLIT > 1
const int q_lane = tid / N_SPLIT;
const int split_idx = tid % N_SPLIT;
#else
const int q_lane = tid;
const int split_idx = 0;
#endif
const int my_query_row = block_q_idx * BLOCK_M + q_lane;
const int query_valid = my_query_row < n_q;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const int mask_head_idx = mask_void != NULL ? head_idx % mask_ne2 : 0;
const int mask_batch_idx = mask_void != NULL ? batch_idx % mask_ne3 : 0;
const global char* q_base = (const global char*)q_void + q_offset;
const global char* k_base = (const global char*)k_void + k_offset;
@@ -75,27 +127,41 @@ __kernel void flash_attn_f32_f16(
const global char* mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
const global char* mask_pad_base = NULL;
if (mask_pad_void != NULL) {
mask_pad_base = (const global char*)mask_pad_void + mask_batch_idx * mask_pad_nb3 + mask_head_idx * mask_pad_nb2;
}
const global char* blk_base = NULL;
if (blk != NULL) {
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
blk_base = blk + (((mask_batch_idx * mask_ne2) + mask_head_idx) * n_q_blocks + block_q_idx) * n_kv_blocks;
}
ACC_TYPE4 q_priv[DK_VEC];
if (my_query_row < n_q) {
ACC_TYPE4 q_priv[SPLIT_DK_VEC];
const int dk_off = split_idx * SPLIT_DK_VEC;
if (query_valid) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
FA_UNROLL
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[dk_off + i]);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
q_priv[i] = (ACC_TYPE4)(0.0f);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
ACC_TYPE4 o_acc[SPLIT_DV_VEC];
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -103,86 +169,369 @@ __kernel void flash_attn_f32_f16(
__local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
__local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
#if N_SPLIT > 1 && !defined(HAS_SUBGROUP_SHUFFLE)
__local ACC_TYPE local_partial[BLOCK_N][WG_SIZE];
__local ACC_TYPE local_p[BLOCK_M][BLOCK_N];
__local ACC_TYPE local_softmax_scale[BLOCK_M];
__local ACC_TYPE local_l_inv[BLOCK_M];
#endif
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
char blk_cur = 1;
if (blk_base != NULL) {
blk_cur = blk_base[k_start / BLOCK_N];
if (blk_cur == 0) continue;
}
const int use_kv_pad = k_pad_void != NULL && k_start + BLOCK_N > n_kv;
const int k_tile_start = use_kv_pad ? 0 : k_start;
const ulong k_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * k_nb1 : k_nb2;
const ulong k_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * k_tile_nb2 : k_nb3;
const ulong v_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * v_nb1 : v_nb2;
const ulong v_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * v_tile_nb2 : v_nb3;
const global char* k_tile_base = use_kv_pad ? (const global char*) k_pad_void : k_base;
const global char* v_tile_base = use_kv_pad ? (const global char*) v_pad_void : v_base;
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
const int row = i / DK_VEC;
const int col = i % DK_VEC;
const int k_row_idx = k_start + row;
if (k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
const int k_row_idx = k_tile_start + row;
if (use_kv_pad || k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_tile_nb3 + head_kv_idx * k_tile_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_tile_base + k_row_offset))[col];
} else {
l_k[row][col] = (KV_DATA_TYPE4)(0.0h);
}
}
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
const int row = i / DV_VEC;
const int col = i % DV_VEC;
const int v_row_idx = k_start + row;
if (v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
const int v_row_idx = k_tile_start + row;
if (use_kv_pad || v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_tile_nb3 + head_kv_idx * v_tile_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_tile_base + v_row_offset))[col];
} else {
l_v[row][col] = (KV_DATA_TYPE4)(0.0h);
}
}
barrier(CLK_LOCAL_MEM_FENCE);
if (my_query_row >= n_q) {
continue;
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
{
const int dv_off = split_idx * SPLIT_DV_VEC;
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE partial0 = 0.0f;
ACC_TYPE partial1 = 0.0f;
FA_UNROLL
for (int k = 0; k < SPLIT_DK_VEC; k++) {
const ACC_TYPE4 qk = q_priv[k];
ACC_TYPE4 dot0 = qk * CONVERT_KV_ACC4(l_k[j ][dk_off + k]);
ACC_TYPE4 dot1 = qk * CONVERT_KV_ACC4(l_k[j+1][dk_off + k]);
partial0 += dot0.s0 + dot0.s1 + dot0.s2 + dot0.s3;
partial1 += dot1.s0 + dot1.s1 + dot1.s2 + dot1.s3;
}
FA_UNROLL
for (int step = 1; step < N_SPLIT; step <<= 1) {
partial0 += sub_group_shuffle_xor(partial0, step);
partial1 += sub_group_shuffle_xor(partial1, step);
}
ACC_TYPE score0 = partial0 * scale;
ACC_TYPE score1 = partial1 * scale;
if (!query_valid) { score0 = FA_M_INIT; score1 = FA_M_INIT; }
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = FA_M_INIT;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = FA_M_INIT;
if (k_row1 >= n_kv) score1 = FA_M_INIT;
if (query_valid && mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
score0 += slope * (ACC_TYPE)mask_ptr[j];
score1 += slope * (ACC_TYPE)mask_ptr[j + 1];
} else {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
// far negative so the tile contributes 0, not exp(0)=1.
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE sp = native_exp(m_i - m_exp);
const ACC_TYPE p0 = native_exp(score0 - m_exp);
const ACC_TYPE p1 = native_exp(score1 - m_exp);
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = o_acc[i] * sp
+ p0 * CONVERT_KV_ACC4(l_v[j ][dv_off + i])
+ p1 * CONVERT_KV_ACC4(l_v[j+1][dv_off + i]);
}
l_i = l_i * sp + p0 + p1;
m_i = m_new;
}
}
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
#elif N_SPLIT > 1
// N_SPLIT>1 fallback (no shuffle): 3-phase local-memory reduction.
// Phase 1 partial dots for all BLOCK_N tokens.
for (int j = 0; j < BLOCK_N; ++j) {
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < SPLIT_DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][dk_off + k]), dot_acc);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
}
l_i = l_i * scale_prev + p0 + p1;
m_i = m_new;
local_partial[j][tid] =
dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3;
}
barrier(CLK_LOCAL_MEM_FENCE); // 1 barrier: partial dots visible
// Phase 2 split_idx==0 reduces partial sums and computes block softmax.
if (split_idx == 0) {
if (query_valid) {
ACC_TYPE m_new = m_i;
for (int j = 0; j < BLOCK_N; ++j) {
const int k_row = k_start + j;
ACC_TYPE score = 0.0f;
FA_UNROLL
for (int s = 0; s < N_SPLIT; s++) {
score += local_partial[j][q_lane * N_SPLIT + s];
}
score *= scale;
if (is_causal && k_row > (n_kv - n_q + my_query_row)) score = FA_M_INIT;
if (k_row >= n_kv) score = FA_M_INIT;
if (mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
score += slope * (ACC_TYPE)mask_ptr[j];
} else {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row < n_kv) score += slope * (ACC_TYPE)mask_ptr[k_row];
}
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_new = max(m_new, score);
local_p[q_lane][j] = score;
}
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE sp = native_exp(m_i - m_exp);
ACC_TYPE l_new = l_i * sp;
for (int j = 0; j < BLOCK_N; ++j) {
const ACC_TYPE p = native_exp(local_p[q_lane][j] - m_exp);
local_p[q_lane][j] = p;
l_new += p;
}
local_softmax_scale[q_lane] = sp;
l_i = l_new;
m_i = m_new;
} else {
local_softmax_scale[q_lane] = 1.0f;
for (int j = 0; j < BLOCK_N; ++j) local_p[q_lane][j] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
// Phase 3 V accumulate using broadcast probabilities.
{
const ACC_TYPE sp_block = local_softmax_scale[q_lane];
const int dv_off = split_idx * SPLIT_DV_VEC;
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] *= sp_block;
}
for (int j = 0; j < BLOCK_N; ++j) {
const ACC_TYPE p = local_p[q_lane][j];
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(l_v[j][dv_off + i]), o_acc[i]);
}
}
}
#else
// N_SPLIT==1: j+=4 unroll. Requires BLOCK_N % 4 == 0.
if (query_valid) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_KV_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_KV_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
s0 += slope * (ACC_TYPE)mask_ptr[j];
s1 += slope * (ACC_TYPE)mask_ptr[j + 1];
s2 += slope * (ACC_TYPE)mask_ptr[j + 2];
s3 += slope * (ACC_TYPE)mask_ptr[j + 3];
} else {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
}
if (logit_softcap > 0.0f) {
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
// far negative so the tile contributes 0, not exp(0)=1.
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE scale_prev = native_exp(m_i - m_exp);
const ACC_TYPE p0 = native_exp(s0 - m_exp);
const ACC_TYPE p1 = native_exp(s1 - m_exp);
const ACC_TYPE p2 = native_exp(s2 - m_exp);
const ACC_TYPE p3 = native_exp(s3 - m_exp);
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = mad(p3, CONVERT_KV_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_KV_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_KV_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_KV_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
#endif
// End of tile: every thread must finish reading l_k/l_v before the
// next iteration's load overwrites them (WAR hazard on local memory).
barrier(CLK_LOCAL_MEM_FENCE);
}
if (my_query_row < n_q) {
// Write output.
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
if (query_valid) {
ACC_TYPE sinks_sp = 1.0f;
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
sinks_sp = exp(m_i - m_final);
l_i = l_i * sinks_sp + exp(m_sink - m_final);
m_i = m_final;
}
const ACC_TYPE l_inv = (l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
const int dv_off = split_idx * SPLIT_DV_VEC;
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_inv > 0.0f) {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#elif N_SPLIT > 1
if (split_idx == 0) {
ACC_TYPE sinks_sp = 1.0f;
if (query_valid && sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
sinks_sp = exp(m_i - m_final);
l_i = l_i * sinks_sp + exp(m_sink - m_final);
m_i = m_final;
}
local_softmax_scale[q_lane] = sinks_sp;
local_l_inv[q_lane] = (query_valid && l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
}
barrier(CLK_LOCAL_MEM_FENCE);
if (query_valid) {
const ACC_TYPE sinks_sp = local_softmax_scale[q_lane];
const ACC_TYPE l_inv = local_l_inv[q_lane];
const int dv_off = split_idx * SPLIT_DV_VEC;
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_inv > 0.0f) {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#else
if (query_valid) {
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -194,17 +543,18 @@ __kernel void flash_attn_f32_f16(
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#endif
}
__kernel void flash_attn_f32_f16_q1(
@@ -258,13 +608,16 @@ __kernel void flash_attn_f32_f16_q1(
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
ACC_TYPE4 q_priv[DK_VEC];
// Q is uniform across WG threads (n_q=1). Share via local memory to
// avoid per-thread q_priv[DK_VEC] dynamic-indexed private array that
// spills to DDR on Adreno.
__local ACC_TYPE4 q_shared[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
}
barrier(CLK_LOCAL_MEM_FENCE);
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -273,14 +626,14 @@ __kernel void flash_attn_f32_f16_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
@@ -296,7 +649,7 @@ __kernel void flash_attn_f32_f16_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -304,7 +657,7 @@ __kernel void flash_attn_f32_f16_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -314,9 +667,9 @@ __kernel void flash_attn_f32_f16_q1(
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
@@ -328,7 +681,7 @@ __kernel void flash_attn_f32_f16_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -338,7 +691,7 @@ __kernel void flash_attn_f32_f16_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -357,7 +710,7 @@ __kernel void flash_attn_f32_f16_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -367,7 +720,257 @@ __kernel void flash_attn_f32_f16_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
// Flash-decoding split pass. gid(2) = q_idx * n_splits + split_idx.
// Partial record per split: [m, l, O[DV]]. Merge kernel applies sink + norm.
#define FA_PARTIAL_FLOATS (2 + DV)
__kernel void flash_attn_f32_f16_q1_split(
const global void * q_void, ulong q_offset,
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
const float scale,
const int n_q,
const int n_kv,
const int n_head,
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
const float max_bias,
const float m0,
const float m1,
const int n_head_log2,
const float logit_softcap,
const int n_head_kv,
const global void * mask_void,
const ulong mask_offset,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3,
global float * partial_void,
const int n_splits,
const int kv_per_split
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
const int split_q_idx = get_global_id(2);
const int split_idx = split_q_idx % n_splits;
const int q_idx = split_q_idx / n_splits;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const int kv_start = split_idx * kv_per_split;
const int kv_end = min(kv_start + kv_per_split, n_kv);
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
const ulong record_idx = ((((ulong) batch_idx * n_head + head_idx) * n_q + q_idx)
* n_splits + split_idx);
global float * rec = partial_void + record_idx * record_stride;
global float4 * rec_o = (global float4 *) (rec + 2);
if (kv_start >= kv_end) {
// Empty split: leave sentinel partial for merge.
if (tid == 0) {
rec[0] = FA_M_INIT;
rec[1] = 0.0f;
}
return;
}
const global char * q_base = (const global char *) q_void + q_offset;
const global char * k_base = (const global char *) k_void + k_offset;
const global char * v_base = (const global char *) v_void + v_offset;
const global char * mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char *) mask_void + mask_offset +
mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2 +
(ulong) q_idx * mask_nb1;
}
// Share Q via local memory (n_q=1 per split -> uniform across WG).
__local ACC_TYPE4 q_shared[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + (ulong) q_idx * q_nb1;
const global Q_DATA_TYPE4 * q_ptr = (const global Q_DATA_TYPE4 *) (q_base + q_row_offset);
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
}
barrier(CLK_LOCAL_MEM_FENCE);
const float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
// Pass 1a split-local max.
ACC_TYPE m_i = FA_M_INIT;
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; ++k) {
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
score += slope * (ACC_TYPE) mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_i = max(m_i, score);
}
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
}
const ACC_TYPE m_c = local_m[0];
// Pass 1b softmax-weighted V accumulate.
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
const global KV_DATA_TYPE4 * v_ptr = (const global KV_DATA_TYPE4 *) (v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; ++k) {
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
score += slope * (ACC_TYPE) mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
const ACC_TYPE p = exp(score - m_c);
l_i += p;
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
}
}
__local ACC_TYPE local_l[Q1_WG_SIZE];
__local ACC_TYPE4 local_o[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
const ACC_TYPE l_c = local_l[0];
if (tid == 0) {
rec[0] = (float) m_c;
rec[1] = (float) l_c;
}
for (int i = 0; i < DV_VEC; ++i) {
local_o[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o[tid] += local_o[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
rec_o[i] = local_o[0];
}
}
}
// FD Pass 2: merge per-split partials into final O. Empty splits drop via exp(-INF)=0.
__kernel void flash_attn_f32_merge(
const global float * partial_void,
global void * o_void,
const ulong o_offset,
const int n_head,
const int n_splits,
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
const global void * sinks_void,
const ulong sinks_offset,
const int n_q
) {
const int lane = get_local_id(0); // 0..DV_VEC-1
const int head_batch_idx = get_global_id(1);
const int q_idx = get_global_id(2);
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
const ulong record_idx_0 = (((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) * n_splits;
const global float * rec0 = partial_void + record_idx_0 * record_stride;
__local ACC_TYPE m_final_shared;
__local ACC_TYPE l_final_shared;
if (lane == 0) {
ACC_TYPE m = FA_M_INIT;
for (int c = 0; c < n_splits; ++c) {
const ACC_TYPE m_c = rec0[c * record_stride + 0];
m = max(m, m_c);
}
ACC_TYPE m_sink = 0.0f;
bool has_sink = false;
if (sinks_void != NULL) {
const global ACC_TYPE * sinks_ptr =
(const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset);
m_sink = sinks_ptr[head_idx];
has_sink = true;
m = max(m, m_sink);
}
ACC_TYPE l = 0.0f;
for (int c = 0; c < n_splits; ++c) {
const ACC_TYPE m_c = rec0[c * record_stride + 0];
const ACC_TYPE l_c = rec0[c * record_stride + 1];
if (m_c > FA_M_INIT) {
l += l_c * exp(m_c - m);
}
}
if (has_sink) {
l += exp(m_sink - m);
}
m_final_shared = m;
l_final_shared = l;
}
barrier(CLK_LOCAL_MEM_FENCE);
const ACC_TYPE m_final = m_final_shared;
const ACC_TYPE l_final = l_final_shared;
const ACC_TYPE l_inv = (l_final > 0.0f) ? (1.0f / l_final) : 0.0f;
ACC_TYPE4 o = (ACC_TYPE4)(0.0f);
for (int c = 0; c < n_splits; ++c) {
const global float * rec_c = rec0 + c * record_stride;
const ACC_TYPE m_c = rec_c[0];
if (m_c <= FA_M_INIT) continue;
const global float4 * rec_oc = (const global float4 *) (rec_c + 2);
const ACC_TYPE scale_c = exp(m_c - m_final);
o = mad((ACC_TYPE4)(scale_c), rec_oc[lane], o);
}
o = o * l_inv;
const ulong o_row_offset = (ulong) batch_idx * o_nb3 + (ulong) q_idx * o_nb2 + (ulong) head_idx * o_nb1;
global O_DATA_TYPE4 * o_row = (global O_DATA_TYPE4 *) ((global char *) o_void + o_offset + o_row_offset);
o_row[lane] = CONVERT_O_DATA4(o);
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,156 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void flash_attn_kv_pad_f16(
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
global void * k_pad_void,
global void * v_pad_void,
const int n_kv,
const int n_head_kv,
const int n_batch,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3
) {
const int row_idx = get_global_id(0);
const int head_kv_idx = get_global_id(1);
const int batch_idx = get_global_id(2);
if (row_idx >= BLOCK_N || head_kv_idx >= n_head_kv || batch_idx >= n_batch) {
return;
}
const int tail_start = n_kv - (n_kv % BLOCK_N);
const int src_row_idx = tail_start + row_idx;
const global char * k_src = (const global char *) k_void + k_offset;
const global char * v_src = (const global char *) v_void + v_offset;
global char * k_pad = (global char *) k_pad_void;
global char * v_pad = (global char *) v_pad_void;
const ulong k_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * k_nb1) + (ulong) row_idx * k_nb1;
const ulong v_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * v_nb1) + (ulong) row_idx * v_nb1;
if (src_row_idx < n_kv) {
const ulong k_src_offset = (ulong) batch_idx * k_nb3 + (ulong) head_kv_idx * k_nb2 + (ulong) src_row_idx * k_nb1;
const ulong v_src_offset = (ulong) batch_idx * v_nb3 + (ulong) head_kv_idx * v_nb2 + (ulong) src_row_idx * v_nb1;
for (ulong i = 0; i < k_nb1; ++i) {
k_pad[k_dst_offset + i] = k_src[k_src_offset + i];
}
for (ulong i = 0; i < v_nb1; ++i) {
v_pad[v_dst_offset + i] = v_src[v_src_offset + i];
}
} else {
for (ulong i = 0; i < k_nb1; ++i) {
k_pad[k_dst_offset + i] = 0;
}
for (ulong i = 0; i < v_nb1; ++i) {
v_pad[v_dst_offset + i] = 0;
}
}
}
__kernel void flash_attn_mask_pad_f16(
const global void * mask_void, ulong mask_offset,
global void * mask_pad_void,
const int n_q,
const int n_kv,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int col_idx = get_global_id(0);
const int q_row = get_global_id(1);
const int mask_slice = get_global_id(2);
if (col_idx >= BLOCK_N || q_row >= n_q || mask_slice >= mask_ne2 * mask_ne3) {
return;
}
const int tail_start = n_kv - (n_kv % BLOCK_N);
const int src_col_idx = tail_start + col_idx;
const int mask_head_idx = mask_slice % mask_ne2;
const int mask_batch_idx = mask_slice / mask_ne2;
const global char * mask_src_base = (const global char *) mask_void + mask_offset +
(ulong) mask_batch_idx * mask_nb3 +
(ulong) mask_head_idx * mask_nb2 +
(ulong) q_row * mask_nb1;
const global half * mask_src = (const global half *) mask_src_base;
global half * mask_pad = (global half *) mask_pad_void;
const ulong dst_idx =
(((ulong) mask_batch_idx * (ulong) mask_ne2 + (ulong) mask_head_idx) * (ulong) n_q + (ulong) q_row) * (ulong) BLOCK_N +
(ulong) col_idx;
mask_pad[dst_idx] = src_col_idx < n_kv ? mask_src[src_col_idx] : (half) (-INFINITY);
}
// Per-KV-tile mask class. 0=all -inf (skip tile), 1=mixed (apply mask),
// 2=all zero, no -inf (skip mask lookup). Causal diagonal tiles are class 1.
__kernel void flash_attn_blk_f16(
const global void * mask_void, ulong mask_offset,
global char * blk,
const int n_q,
const int n_kv,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int kv_block_idx = get_global_id(0);
const int q_block_idx = get_global_id(1);
const int mask_slice = get_global_id(2);
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
const int n_kv_blocks = (n_kv + BLOCK_N - 1) / BLOCK_N;
if (kv_block_idx >= n_kv_blocks || q_block_idx >= n_q_blocks || mask_slice >= mask_ne2 * mask_ne3) {
return;
}
const int mask_head_idx = mask_slice % mask_ne2;
const int mask_batch_idx = mask_slice / mask_ne2;
const int q_start = q_block_idx * BLOCK_M;
const int k_start = kv_block_idx * BLOCK_N;
const int q_count = min(BLOCK_M, n_q - q_start);
const int k_count = min(BLOCK_N, n_kv - k_start);
const half neg_max_half = (half) (-65504.0f);
char has_unmasked = 0;
char has_masked = 0;
char has_nonzero = 0;
const global char * mask_base = (const global char *) mask_void + mask_offset +
(ulong) mask_batch_idx * mask_nb3 +
(ulong) mask_head_idx * mask_nb2;
for (int qi = 0; qi < q_count; ++qi) {
const global half * mask_row = (const global half *) (mask_base + (ulong) (q_start + qi) * mask_nb1) + k_start;
for (int ki = 0; ki < k_count; ++ki) {
const half v = mask_row[ki];
if (v <= neg_max_half) {
has_masked = 1;
} else {
has_unmasked = 1;
if (v != (half) 0.0f) {
has_nonzero = 1;
}
}
}
if (has_masked && has_unmasked) break; // mixed tile — short-circuit.
}
char res;
if (has_unmasked == 0) {
res = 0;
} else if (has_masked || has_nonzero) {
res = 1;
} else {
res = 2;
}
blk[((ulong) mask_slice * (ulong) n_q_blocks + (ulong) q_block_idx) * (ulong) n_kv_blocks + (ulong) kv_block_idx] = res;
}
+500
View File
@@ -158,6 +158,239 @@ kernel void kernel_set_rows_f32_i32(
}
}
// f32 -> q8_0 quantize set_rows. Block = half d + char qs[32].
#define QK8_0 32
inline void quantize_q8_0_block(global float * x, global char * qs, global half * d_out) {
float amax = 0.0f;
for (int j = 0; j < QK8_0; j++) {
amax = fmax(amax, fabs(x[j]));
}
float d = amax / 127.0f;
float id = (d != 0.0f) ? 127.0f / amax : 0.0f;
vstore_half(d, 0, d_out);
for (int j = 0; j < QK8_0; j++) {
qs[j] = (char)((int)round(x[j] * id));
}
}
kernel void kernel_set_rows_q8_0_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * y = dst_row + blk * (2 + QK8_0);
quantize_q8_0_block(x, y + 2, (global half *)y);
}
}
kernel void kernel_set_rows_q8_0_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * y = dst_row + blk * (2 + QK8_0);
quantize_q8_0_block(x, y + 2, (global half *)y);
}
}
// SoA q8_0 variants. dst_q: int8[QK8_0] per block; dst_d: fp16 scale per block.
// Layout matches kernel_convert_block_q8_0; block index follows dst element order.
kernel void kernel_set_rows_q8_0_soa_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0;
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * q = q_row + blk * QK8_0;
quantize_q8_0_block(x, q, d_row + blk);
}
}
kernel void kernel_set_rows_q8_0_soa_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0;
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * q = q_row + blk * QK8_0;
quantize_q8_0_block(x, q, d_row + blk);
}
}
kernel void kernel_set_rows_f16_i32(
global char * src0,
ulong offset0,
@@ -206,3 +439,270 @@ kernel void kernel_set_rows_f16_i32(
dst_row[ind] = src_row[ind];
}
}
// f32 -> q4_0 quantize set_rows. Block = half d + uchar qs[16] (shuffled
// nibbles: qs[j] low/high = elem j / j+16).
// Dequant: val[i] = d * (nibble_i - 8)
// nblk0 = number of q4_0 blocks per row = ne00 / 32.
#define QK4_0 32
#define Q4_0_BLOCK_SIZE 18
inline void quantize_q4_0_block(global float * x, global uchar * qs, global half * d_out) {
// Find the signed value with the largest absolute magnitude (matches ggml ref).
float max = 0.0f;
float amax = 0.0f;
for (int j = 0; j < QK4_0; j++) {
float v = x[j];
float a = fabs(v);
if (a > amax) {
amax = a;
max = v;
}
}
float d = max / -8.0f;
float id = (d != 0.0f) ? 1.0f / d : 0.0f;
vstore_half(d, 0, d_out);
for (int j = 0; j < QK4_0/2; j++) {
float x0 = x[j] * id;
float x1 = x[j + QK4_0/2] * id;
int i0 = (int)(x0 + 8.5f);
int i1 = (int)(x1 + 8.5f);
if (i0 < 0) i0 = 0;
if (i0 > 15) i0 = 15;
if (i1 < 0) i1 = 0;
if (i1 > 15) i1 = 15;
qs[j] = (uchar)i0 | ((uchar)i1 << 4);
}
}
kernel void kernel_set_rows_q4_0_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global char * y = dst_row + blk * Q4_0_BLOCK_SIZE;
global half * yd = (global half *)(y);
global uchar * yqs = (global uchar *)(y + 2);
quantize_q4_0_block(x, yqs, yd);
}
}
kernel void kernel_set_rows_q4_0_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global char * y = dst_row + blk * Q4_0_BLOCK_SIZE;
global half * yd = (global half *)(y);
global uchar * yqs = (global uchar *)(y + 2);
quantize_q4_0_block(x, yqs, yd);
}
}
// SoA variants for q4_0 dst. Used when the backend has split block_q4_0 records
// into separate quant (dst_q) and scale (dst_d) sub-buffers same pattern as
// the q8_0 SoA variants above.
//
// Layout (matches kernel_convert_block_q4_0, the "shuffled" variant):
// dst_q: contiguous 16 packed nibbles per block, block i at offset i * 16 bytes.
// dst_d: contiguous fp16 scales, block i at offset i * 2 bytes.
// Nibble layout inside each byte is unchanged from AoS: qs[j] low nibble = element j,
// qs[j] high nibble = element j+16. kernel_restore_block_q4_0 copies bytes as-is.
kernel void kernel_set_rows_q4_0_soa_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2);
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global uchar * qs = q_row + blk * (QK4_0/2);
global half * d_bk = d_row + blk;
quantize_q4_0_block(x, qs, d_bk);
}
}
kernel void kernel_set_rows_q4_0_soa_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2);
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global uchar * qs = q_row + blk * (QK4_0/2);
global half * d_bk = d_row + blk;
quantize_q4_0_block(x, qs, d_bk);
}
}
+3 -66
View File
@@ -1270,77 +1270,14 @@ void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecode
}
std::string GgmlOvDecoder::compute_op_type(const ggml_tensor * node) {
static const std::map<ggml_op, std::string> ops = {
{GGML_OP_NONE, "GGML_OP_NONE" },
{GGML_OP_ACC, "GGML_OP_ACC" },
{GGML_OP_ADD, "GGML_OP_ADD" },
{GGML_OP_ADD1, "GGML_OP_ADD1" },
{GGML_OP_ADD_ID, "GGML_OP_ADD_ID" },
{GGML_OP_CONCAT, "GGML_OP_CONCAT" },
{GGML_OP_CONT, "GGML_OP_CONT" },
{GGML_OP_DIV, "GGML_OP_DIV" },
{GGML_OP_DUP, "GGML_OP_DUP" },
{GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS" },
{GGML_OP_MUL, "GGML_OP_MUL" },
{GGML_OP_MUL_MAT, "GGML_OP_MUL_MAT" },
{GGML_OP_MUL_MAT_ID, "GGML_OP_MUL_MAT_ID" },
{GGML_OP_PERMUTE, "GGML_OP_PERMUTE" },
{GGML_OP_RESHAPE, "GGML_OP_RESHAPE" },
{GGML_OP_RMS_NORM, "GGML_OP_RMS_NORM" },
{GGML_OP_NORM, "GGML_OP_NORM" },
{GGML_OP_ROPE, "GGML_OP_ROPE" },
{GGML_OP_SCALE, "GGML_OP_SCALE" },
{GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX" },
{GGML_OP_SUM_ROWS, "GGML_OP_SUM_ROWS" },
{GGML_OP_SUB, "GGML_OP_SUB" },
{GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE" },
{GGML_OP_VIEW, "GGML_OP_VIEW" },
{GGML_OP_SET_ROWS, "GGML_OP_SET_ROWS" },
{GGML_OP_CPY, "GGML_OP_CPY" },
{GGML_OP_FLASH_ATTN_EXT, "GGML_OP_FLASH_ATTN_EXT" },
{GGML_OP_L2_NORM, "GGML_OP_L2_NORM" },
{GGML_OP_CLAMP, "GGML_OP_CLAMP" },
{GGML_OP_PAD, "GGML_OP_PAD" },
{GGML_OP_SSM_CONV, "GGML_OP_SSM_CONV" },
{GGML_OP_GATED_DELTA_NET, "GGML_OP_GATED_DELTA_NET"},
{GGML_OP_ARGSORT, "GGML_OP_ARGSORT" },
{GGML_OP_REPEAT, "GGML_OP_REPEAT" },
{GGML_OP_IM2COL, "GGML_OP_IM2COL" }
};
static const std::map<ggml_unary_op, std::string> unary_ops = {
{GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS" },
{GGML_UNARY_OP_SGN, "GGML_UNARY_OP_SGN" },
{GGML_UNARY_OP_NEG, "GGML_UNARY_OP_NEG" },
{GGML_UNARY_OP_STEP, "GGML_UNARY_OP_STEP" },
{GGML_UNARY_OP_TANH, "GGML_UNARY_OP_TANH" },
{GGML_UNARY_OP_ELU, "GGML_UNARY_OP_ELU" },
{GGML_UNARY_OP_RELU, "GGML_UNARY_OP_RELU" },
{GGML_UNARY_OP_SIGMOID, "GGML_UNARY_OP_SIGMOID" },
{GGML_UNARY_OP_GELU, "GGML_UNARY_OP_GELU" },
{GGML_UNARY_OP_GELU_QUICK, "GGML_UNARY_OP_GELU_QUICK" },
{GGML_UNARY_OP_SILU, "GGML_UNARY_OP_SILU" },
{GGML_UNARY_OP_SOFTPLUS, "GGML_UNARY_OP_SOFTPLUS" },
{GGML_UNARY_OP_HARDSWISH, "GGML_UNARY_OP_HARDSWISH" },
{GGML_UNARY_OP_HARDSIGMOID, "GGML_UNARY_OP_HARDSIGMOID"},
{GGML_UNARY_OP_EXP, "GGML_UNARY_OP_EXP" },
{GGML_UNARY_OP_COUNT, "GGML_UNARY_OP_COUNT" }
};
static const std::map<ggml_glu_op, std::string> glu_ops = {
{GGML_GLU_OP_SWIGLU, "GGML_GLU_OP_SWIGLU"},
{GGML_GLU_OP_GEGLU, "GGML_GLU_OP_GEGLU" },
{GGML_GLU_OP_REGLU, "GGML_GLU_OP_REGLU" }
};
switch (node->op) {
case GGML_OP_UNARY:
return unary_ops.at(ggml_get_unary_op(node));
return std::string("GGML_UNARY_OP_") + ggml_unary_op_name(ggml_get_unary_op(node));
case GGML_OP_GLU:
return glu_ops.at(ggml_get_glu_op(node));
return std::string("GGML_GLU_OP_") + ggml_glu_op_name(ggml_get_glu_op(node));
default:
return ops.at(node->op);
return std::string("GGML_OP_") + ggml_op_name(node->op);
}
static const std::string unknown_op = "UNKNOWN_GGML_OP";
return unknown_op;
}
const std::string & GgmlOvDecoder::get_op_type(int node_idx) const {
+4
View File
@@ -1053,6 +1053,10 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
(op->ne[0] == 2 && op->ne[1] == 4 && op->ne[2] == 3 && op->ne[3] == 2)) {
return true;
}
// CPY into a strided view of a larger buffer (recurrent-state snapshots) not supported
if (op->view_src && ggml_nbytes(op) != ggml_nbytes(op->view_src)) {
return true;
}
break;
}
case GGML_OP_MUL_MAT: {
+19 -5
View File
@@ -17,6 +17,22 @@ namespace frontend {
namespace ggml {
namespace op {
static ov::Output<ov::Node> reshape_add_id_input_to_2d(const ov::Output<ov::Node> & input,
const ov::PartialShape & input_shape,
const std::vector<int> & dims) {
const auto actual_shape = input.get_partial_shape();
if (actual_shape.rank().is_static() && actual_shape.rank().get_length() == 2) {
return input;
}
if (input_shape.rank().is_static() && input_shape.rank().get_length() == 2) {
return input;
}
auto shape = std::make_shared<ov::op::v3::ShapeOf>(input, ov::element::i64);
return std::make_shared<ov::op::v1::Reshape>(input, get_dimensions(shape, dims), false);
}
OutputVector translate_add_id(const NodeContext & context) {
num_inputs_check(context, 3, 3);
@@ -28,11 +44,9 @@ OutputVector translate_add_id(const NodeContext & context) {
// input: [1, n_token, n_used, n_embd]
// bias: [1, 1, n_expert, n_embd]
// ids: [1, 1, n_token, n_used]
auto bias_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(bias, ov::element::i64);
auto ids_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(ids, ov::element::i64);
bias = std::make_shared<ov::op::v1::Reshape>(bias, get_dimensions(bias_shape_4d, {2, 3}), false);
ids = std::make_shared<ov::op::v1::Reshape>(ids, get_dimensions(ids_shape_4d, {2, 3}), false);
// Model bias constants may already be stored as [n_expert, n_embd].
bias = reshape_add_id_input_to_2d(bias, context.get_input_shape(1), {2, 3});
ids = reshape_add_id_input_to_2d(ids, context.get_input_shape(2), {2, 3});
if (ids.get_element_type() != ov::element::i32 && ids.get_element_type() != ov::element::i64) {
ids = std::make_shared<ov::op::v0::Convert>(ids, ov::element::i32);
@@ -3,8 +3,11 @@
#include "../utils.h"
#include <cstdint>
#include <limits>
#include <memory>
#include <openvino/core/node_output.hpp>
#include <openvino/op/add.hpp>
#include <openvino/op/clamp.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/sigmoid.hpp>
@@ -15,7 +18,7 @@ namespace frontend {
namespace ggml {
namespace op {
OutputVector translate_glu_swiglu(const NodeContext & context) {
static std::pair<ov::Output<ov::Node>, ov::Output<ov::Node>> get_glu_inputs(const NodeContext & context) {
num_inputs_check(context, 1, 2);
ov::Output<ov::Node> src0;
@@ -52,6 +55,12 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
std::swap(src0, src1);
}
return {src0, src1};
}
OutputVector translate_glu_swiglu(const NodeContext & context) {
auto [src0, src1] = get_glu_inputs(context);
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src0);
auto silu = std::make_shared<ov::op::v1::Multiply>(src0, sigmoid);
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src1);
@@ -59,6 +68,27 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
return rename_outputs_with_suffix({res}, context.get_name());
}
OutputVector translate_glu_swiglu_oai(const NodeContext & context) {
auto [src0, src1] = get_glu_inputs(context);
const int32_t * params = context.get_output_op_params();
const float alpha = reinterpret_cast<const float *>(params)[2];
const float limit = reinterpret_cast<const float *>(params)[3];
auto gate = std::make_shared<ov::op::v0::Clamp>(src0, -std::numeric_limits<float>::infinity(), limit);
auto alpha_const = ov::op::v0::Constant::create(ov::element::f32, {}, {alpha});
auto scaled_gate = std::make_shared<ov::op::v1::Multiply>(gate, alpha_const);
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(scaled_gate);
auto out_glu = std::make_shared<ov::op::v1::Multiply>(gate, sigmoid);
auto up = std::make_shared<ov::op::v0::Clamp>(src1, -limit, limit);
auto one = ov::op::v0::Constant::create(ov::element::f32, {}, {1.0f});
auto up_plus_one = std::make_shared<ov::op::v1::Add>(up, one);
auto res = std::make_shared<ov::op::v1::Multiply>(out_glu, up_plus_one);
return rename_outputs_with_suffix({res}, context.get_name());
}
} // namespace op
} // namespace ggml
} // namespace frontend
@@ -2,23 +2,135 @@
#include "../op_table.h"
#include "../utils.h"
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <openvino/op/bitwise_and.hpp>
#include <openvino/op/bitwise_right_shift.hpp>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/shape_of.hpp>
#include <openvino/op/squeeze.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <vector>
namespace ov {
namespace frontend {
namespace ggml {
namespace op {
namespace {
std::shared_ptr<ov::op::v0::Constant> const_i64(const std::vector<int64_t> & values) {
return ov::op::v0::Constant::create(ov::element::i64, ov::Shape{values.size()}, values);
}
ov::Output<ov::Node> slice_axis(const ov::Output<ov::Node> & input, int64_t axis, int64_t begin, int64_t end) {
return std::make_shared<ov::op::v8::Slice>(input, const_i64({begin}), const_i64({end}), const_i64({1}),
const_i64({axis}));
}
ov::Output<ov::Node> translate_mul_mat_id_mxfp4_packed(const NodeContext & context,
ov::Output<ov::Node> expert_weights,
ov::Output<ov::Node> activations,
ov::Output<ov::Node> ids) {
auto packed_shape = expert_weights.get_partial_shape().to_shape();
FRONT_END_OP_CONVERSION_CHECK(packed_shape.size() == 5 && packed_shape[4] == 17,
"Expected packed MXFP4 expert weights with shape [1, n_expert, m, k_blocks, 17]");
const int64_t n_expert = static_cast<int64_t>(packed_shape[1]);
const int64_t rows = static_cast<int64_t>(packed_shape[2]);
const int64_t k_blocks = static_cast<int64_t>(packed_shape[3]);
const int64_t qk = 32;
const int64_t cols = k_blocks * qk;
auto packed_shape_4d = const_i64({n_expert, rows, k_blocks, 17});
expert_weights = std::make_shared<ov::op::v1::Reshape>(expert_weights, packed_shape_4d, false);
auto activations_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(activations, ov::element::i64);
auto ids_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(ids, ov::element::i64);
auto activations_shape_3d = get_dimensions(activations_shape_4d, {1, 2, 3});
auto ids_shape_2d = get_dimensions(ids_shape_4d, {2, 3});
activations = std::make_shared<ov::op::v1::Reshape>(activations, activations_shape_3d, false);
ids = std::make_shared<ov::op::v1::Reshape>(ids, ids_shape_2d, false);
if (ids.get_element_type() != ov::element::i32 && ids.get_element_type() != ov::element::i64) {
ids = std::make_shared<ov::op::v0::Convert>(ids, ov::element::i32);
}
auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
static const std::vector<float> f4e2m1_lut = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
-0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
std::vector<float> e8m0_lut(256);
for (size_t i = 0; i < e8m0_lut.size(); ++i) {
uint32_t bits = static_cast<uint32_t>(i) << 23;
memcpy(&e8m0_lut[i], &bits, sizeof(float));
}
e8m0_lut[0] = std::numeric_limits<float>::min() / 2.0f;
e8m0_lut[255] = std::numeric_limits<float>::quiet_NaN();
auto f4_lut = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{f4e2m1_lut.size()}, f4e2m1_lut);
auto scale_lut = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{e8m0_lut.size()}, e8m0_lut);
auto selected_packed_weights = std::make_shared<ov::op::v8::Gather>(expert_weights, ids, gather_axis);
auto scale_byte = slice_axis(selected_packed_weights, 4, 0, 1);
auto qs = slice_axis(selected_packed_weights, 4, 1, 17);
auto low = std::make_shared<ov::op::v13::BitwiseAnd>(
qs, ov::op::v0::Constant::create(ov::element::u8, ov::Shape{}, {0x0F}), ov::op::AutoBroadcastType::NUMPY);
auto high_shift = std::make_shared<ov::op::v15::BitwiseRightShift>(
qs, ov::op::v0::Constant::create(ov::element::u8, ov::Shape{}, {4}), ov::op::AutoBroadcastType::NUMPY);
auto nibbles = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{low, high_shift}, 4);
auto nibble_indices = std::make_shared<ov::op::v0::Convert>(nibbles, ov::element::i32);
auto weights_f32 = std::make_shared<ov::op::v8::Gather>(f4_lut, nibble_indices, gather_axis);
auto scale_indices = std::make_shared<ov::op::v0::Convert>(scale_byte, ov::element::i32);
auto scales_f32 = std::make_shared<ov::op::v8::Gather>(scale_lut, scale_indices, gather_axis);
ov::Output<ov::Node> selected_weights = std::make_shared<ov::op::v1::Multiply>(weights_f32, scales_f32,
ov::op::AutoBroadcastType::NUMPY);
auto ids_shape = std::make_shared<ov::op::v3::ShapeOf>(ids, ov::element::i64);
auto selected_weights_target_dims = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{get_dimensions(ids_shape, {0, 1}), const_i64({rows, cols})}, 0);
selected_weights = std::make_shared<ov::op::v1::Reshape>(selected_weights, selected_weights_target_dims, false);
auto activations_shape = std::make_shared<ov::op::v3::ShapeOf>(activations, ov::element::i64);
ov::Output<ov::Node> acts_target_dims = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{
get_dimensions(activations_shape, {0}),
get_dimensions(ids_shape, {1}),
get_dimensions(activations_shape, {2}),
},
0);
ov::Output<ov::Node> acts_broadcasted =
std::make_shared<ov::op::v3::Broadcast>(activations, acts_target_dims, ov::op::BroadcastType::BIDIRECTIONAL);
auto activations_expanded = std::make_shared<ov::op::v0::Unsqueeze>(acts_broadcasted, const_i64({2}));
ov::Output<ov::Node> result =
std::make_shared<ov::op::v0::MatMul>(activations_expanded, selected_weights, false, true);
auto batch_dim = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto row_dim = ov::op::v0::Constant::create(ov::element::i64, {1}, {rows});
auto result_target_dims = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{batch_dim, get_dimensions(ids_shape, {0, 1}), row_dim}, 0);
result = std::make_shared<ov::op::v1::Reshape>(result, result_target_dims, false);
const auto output_type = context.get_output_type();
if (result.get_element_type() != output_type) {
result = std::make_shared<ov::op::v0::Convert>(result, output_type);
}
return result;
}
} // namespace
OutputVector translate_mul_mat_id(const NodeContext & context) {
num_inputs_check(context, 3, 3);
@@ -26,6 +138,12 @@ OutputVector translate_mul_mat_id(const NodeContext & context) {
auto activations = process_view_input_new(context, 1);
auto ids = process_view_input_new(context, 2);
if (expert_weights.get_element_type() == ov::element::u8 && expert_weights.get_partial_shape().rank().is_static() &&
expert_weights.get_partial_shape().rank().get_length() == 5) {
return rename_outputs_with_suffix({translate_mul_mat_id_mxfp4_packed(context, expert_weights, activations, ids)},
context.get_name());
}
// OpenVINO sees GGML tensors in reversed dimension order:
// weights: [1, n_expert, m, k]
// activations: [1, n_tokens, n_used_or_1, k]
+65 -5
View File
@@ -6,12 +6,16 @@
#include <cstdint>
#include <cstring>
#include <memory>
#include <openvino/op/broadcast.hpp>
#include <openvino/frontend/exception.hpp>
#include <openvino/op/add.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/shape_of.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/softmax.hpp>
#include <vector>
@@ -20,12 +24,31 @@ namespace frontend {
namespace ggml {
namespace op {
static bool is_static_one(const ov::Dimension & dim) {
return dim.is_static() && dim.get_length() == 1;
}
static bool same_static_dim(const ov::Dimension & lhs, const ov::Dimension & rhs) {
return lhs.is_static() && rhs.is_static() && lhs.get_length() == rhs.get_length();
}
static bool is_attention_sinks_input_shape(const ov::PartialShape & candidate, const ov::PartialShape & logits_shape) {
if (candidate.rank().is_dynamic() || logits_shape.rank().is_dynamic() || candidate.rank().get_length() != 4 ||
logits_shape.rank().get_length() != 4) {
return false;
}
return is_static_one(candidate[0]) && is_static_one(candidate[1]) && is_static_one(candidate[2]) &&
same_static_dim(candidate[3], logits_shape[1]);
}
// Reimplementation of GGML_OP_SOFT_MAX semantics for OpenVINO backend:
// 1) logits = src0 * scale
// 2) logits += mask (if provided)
// 3) softmax over the last dimension
// 3) append attention sinks as hidden logits (if provided)
// 4) softmax over the last dimension and remove the hidden sink column
OutputVector translate_soft_max(const NodeContext & context) {
num_inputs_check(context, 1, 2);
num_inputs_check(context, 1, 3);
float scale = 1.0f;
float max_bias = 0.0f;
@@ -33,6 +56,11 @@ OutputVector translate_soft_max(const NodeContext & context) {
memcpy(&max_bias, (float *) context.get_output_op_params() + 1, sizeof(float));
ov::Output<ov::Node> logits = context.get_input(0);
const bool second_input_is_sinks =
context.get_input_size() == 2 && is_attention_sinks_input_shape(context.get_input_shape(1), context.get_output_shape());
const bool has_mask = context.get_input_size() > 1 && !second_input_is_sinks;
const bool has_sinks = second_input_is_sinks || context.get_input_size() > 2;
const size_t sinks_input_idx = second_input_is_sinks ? 1 : 2;
// Apply scale first: logits = src0 * scale
if (scale != 1.0f) {
@@ -41,12 +69,12 @@ OutputVector translate_soft_max(const NodeContext & context) {
logits = std::make_shared<ov::op::v1::Multiply>(logits, scale_const);
}
FRONT_END_CHECK_IMPLEMENTED(!(max_bias > 0.0f && context.get_input_size() < 2),
FRONT_END_CHECK_IMPLEMENTED(!(max_bias > 0.0f && !has_mask),
"OpenVINO softmax ALiBi path requires mask input");
// Optional mask add: logits += mask
// For max_bias > 0 (ALiBi), apply per-head slope to mask before adding.
if (context.get_input_size() > 1) {
if (has_mask) {
ov::Output<ov::Node> mask = context.get_input(1);
// For stateful
@@ -94,8 +122,40 @@ OutputVector translate_soft_max(const NodeContext & context) {
logits = std::make_shared<ov::op::v1::Add>(logits, mask);
}
ov::Output<ov::Node> softmax_input = logits;
if (has_sinks) {
ov::Output<ov::Node> sinks = context.get_input(sinks_input_idx);
if (sinks.get_element_type() != logits.get_element_type()) {
sinks = std::make_shared<ov::op::v0::Convert>(sinks, logits.get_element_type());
}
auto sink_shape = ov::op::v0::Constant::create(ov::element::i64, {4}, {1, -1, 1, 1});
auto sinks_4d = std::make_shared<ov::op::v1::Reshape>(sinks, sink_shape, false);
auto logits_shape = std::make_shared<ov::op::v3::ShapeOf>(logits, ov::element::i64);
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
auto four = ov::op::v0::Constant::create(ov::element::i64, {1}, {4});
auto shape_axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto sink_prefix_shape = std::make_shared<ov::op::v8::Slice>(logits_shape, zero, three, one, shape_axis);
auto sink_last_dim = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto sink_broadcast_shape = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{sink_prefix_shape, sink_last_dim}, 0);
auto sink_column = std::make_shared<ov::op::v3::Broadcast>(sinks_4d, sink_broadcast_shape,
ov::op::BroadcastType::BIDIRECTIONAL);
softmax_input = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{logits, sink_column}, 3);
auto softmax_with_sink = std::make_shared<ov::op::v8::Softmax>(softmax_input, -1);
auto original_last_dim = std::make_shared<ov::op::v8::Slice>(logits_shape, three, four, one, shape_axis);
auto res = std::make_shared<ov::op::v8::Slice>(softmax_with_sink, zero, original_last_dim, one, three);
return rename_outputs_with_suffix({res}, context.get_name());
}
// Softmax along last dimension (equivalent to ggml softmax over ne[0]).
auto res = std::make_shared<ov::op::v8::Softmax>(logits, -1);
auto res = std::make_shared<ov::op::v8::Softmax>(softmax_input, -1);
return rename_outputs_with_suffix({res}, context.get_name());
}
@@ -47,6 +47,7 @@ std::unordered_map<std::string, CreatorFunction> get_supported_ops() {
{"GGML_UNARY_OP_TANH", op::translate_1to1_match_1_input<v0::Tanh> },
{"GGML_OP_VIEW", op::translate_view },
{"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu },
{"GGML_GLU_OP_SWIGLU_OAI", op::translate_glu_swiglu_oai },
{"GGML_GLU_OP_GEGLU", op::translate_glu_geglu },
{"GGML_OP_SET_ROWS", op::translate_set_rows },
{"GGML_OP_CPY", op::translate_cpy },
@@ -32,6 +32,7 @@ GGML_OP_CONVERTER(translate_soft_max);
GGML_OP_CONVERTER(translate_transpose);
GGML_OP_CONVERTER(translate_view);
GGML_OP_CONVERTER(translate_glu_swiglu);
GGML_OP_CONVERTER(translate_glu_swiglu_oai);
GGML_OP_CONVERTER(translate_glu_geglu);
GGML_OP_CONVERTER(translate_set_rows);
GGML_OP_CONVERTER(translate_cpy);
+102 -48
View File
@@ -2,8 +2,10 @@
#include "ggml-sycl/common.hpp"
#include "ggml-sycl/presets.hpp"
static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
static void norm_f32(const float* x, float* dst, const int ncols,
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {
const int nrows = item_ct1.get_group_range(2);
const int nchannels = item_ct1.get_group_range(1);
@@ -16,16 +18,16 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
const int tid = item_ct1.get_local_id(2);
const int nwarps = nthreads / WARP_SIZE;
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row});
const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row});
x += strided_offset;
dst += packed_offset;
x += src_offset;
dst += dst_offset;
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
const float xi = x[col * src_stride_col];
mean_var.x() += xi;
mean_var.y() += xi * xi;
}
@@ -54,7 +56,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const int64_t
const float inv_std = sycl::rsqrt(var + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[col] = (x[col] - mean) * inv_std;
dst[col * dst_stride_col] = (x[col * src_stride_col] - mean) * inv_std;
}
}
@@ -145,8 +147,10 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
}
}
static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
static void rms_norm_f32(const float* x, float* dst, const int ncols,
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
const int nrows = item_ct1.get_group_range(2);
const int nchannels = item_ct1.get_group_range(1);
@@ -160,17 +164,17 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
const int tid = item_ct1.get_local_id(2);
const int nwarps = nthreads / WARP_SIZE;
const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});
const auto src_offset = calculate_offset<3>({src_stride_sample, src_stride_channel, src_stride_row}, {sample, channel, row});
const auto dst_offset = calculate_offset<3>({dst_stride_sample, dst_stride_channel, dst_stride_row}, {sample, channel, row});
x += strided_offset;
dst += packed_offset;
x += src_offset;
dst += dst_offset;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
const float xi = x[col * src_stride_col];
tmp += xi * xi;
}
@@ -198,14 +202,15 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6
const float scale = sycl::rsqrt(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
dst[col * dst_stride_col] = scale * x[col * src_stride_col];
}
}
template<int warp_size>
static void l2_norm_f32(const float * x, float * dst, const int ncols,
const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps,
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel,
const int64_t src_stride_sample, const int64_t dst_stride_col, const int64_t dst_stride_row,
const int64_t dst_stride_channel, const int64_t dst_stride_sample, const float eps,
const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) {
const int nrows = item_ct1.get_group_range(2);
const int nchannels = item_ct1.get_group_range(1);
@@ -215,13 +220,13 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols,
const int sample = item_ct1.get_group(0);
const int tid = item_ct1.get_local_id(2);
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
x += sample*src_stride_sample + channel*src_stride_channel + row*src_stride_row;
dst += sample*dst_stride_sample + channel*dst_stride_channel + row*dst_stride_row;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
const float xi = x[col * src_stride_col];
tmp += xi * xi;
}
@@ -229,12 +234,13 @@ static void l2_norm_f32(const float * x, float * dst, const int ncols,
const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
dst[col * dst_stride_col] = scale * x[col * src_stride_col];
}
}
static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
const float eps, queue_ptr stream, int device) {
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
@@ -245,7 +251,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
norm_f32(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1, nullptr, WARP_SIZE);
});
});
}
@@ -265,7 +274,10 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
norm_f32(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
@@ -319,7 +331,9 @@ static void group_norm_f32_sycl(const float* x, float* dst,
}
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {
const int64_t src_stride_col, const int64_t src_stride_row, const int64_t src_stride_channel, const int64_t src_stride_sample,
const int64_t dst_stride_col, const int64_t dst_stride_row, const int64_t dst_stride_channel, const int64_t dst_stride_sample,
const float eps, queue_ptr stream, int device) {
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
const sycl::range<3> global_dims(nsamples, nchannels, nrows);
@@ -330,7 +344,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
rms_norm_f32(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1, nullptr, WARP_SIZE);
});
});
}
@@ -350,7 +367,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const
sycl::nd_range<3>(global_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(WARP_SIZE)]] {
rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
rms_norm_f32(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
}
@@ -363,9 +383,14 @@ static void l2_norm_f32_sycl(const float * x,
const int nrows,
const int nchannels,
const int nsamples,
const int64_t stride_row,
const int64_t stride_channel,
const int64_t stride_sample,
const int64_t src_stride_col,
const int64_t src_stride_row,
const int64_t src_stride_channel,
const int64_t src_stride_sample,
const int64_t dst_stride_col,
const int64_t dst_stride_row,
const int64_t dst_stride_channel,
const int64_t dst_stride_sample,
const float eps,
queue_ptr stream,
int device) {
@@ -379,7 +404,10 @@ static void l2_norm_f32_sycl(const float * x,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(warp_size)]] {
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1,
l2_norm_f32<warp_size>(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1,
nullptr, warp_size);
});
});
@@ -398,7 +426,9 @@ static void l2_norm_f32_sycl(const float * x,
block_dims),
[=](sycl::nd_item<3> item_ct1)
[[sycl::reqd_sub_group_size(warp_size)]] {
l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample,
l2_norm_f32<warp_size>(x, dst, ncols,
src_stride_col, src_stride_row, src_stride_channel, src_stride_sample,
dst_stride_col, dst_stride_row, dst_stride_channel, dst_stride_sample,
eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
});
});
@@ -421,12 +451,20 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
const size_t tdst = ggml_type_size(dst->type);
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
const int64_t ss0 = nb00 / ts0;
const int64_t ss1 = nb01 / ts0;
const int64_t ss2 = nb02 / ts0;
const int64_t ss3 = nb03 / ts0;
const int64_t ds0 = nb0 / tdst;
const int64_t ds1 = nb1 / tdst;
const int64_t ds2 = nb2 / tdst;
const int64_t ds3 = nb3 / tdst;
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03,
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device);
}
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
@@ -465,11 +503,19 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_TENSOR_UNARY_OP_LOCALS
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
const size_t tdst = ggml_type_size(dst->type);
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
const int64_t ss0 = nb00 / ts0;
const int64_t ss1 = nb01 / ts0;
const int64_t ss2 = nb02 / ts0;
const int64_t ss3 = nb03 / ts0;
const int64_t ds0 = nb0 / tdst;
const int64_t ds1 = nb1 / tdst;
const int64_t ds2 = nb2 / tdst;
const int64_t ds3 = nb3 / tdst;
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03,
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, main_stream, ctx.device);
}
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
@@ -644,13 +690,21 @@ void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
GGML_ASSERT(eps >= 0.0f);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
const size_t tdst = ggml_type_size(dst->type);
GGML_ASSERT(nb00 % ts0 == 0 && nb01 % ts0 == 0 && nb02 % ts0 == 0 && nb03 % ts0 == 0);
GGML_ASSERT(nb0 % tdst == 0 && nb1 % tdst == 0 && nb2 % tdst == 0 && nb3 % tdst == 0);
const int64_t ss0 = nb00 / ts0;
const int64_t ss1 = nb01 / ts0;
const int64_t ss2 = nb02 / ts0;
const int64_t ss3 = nb03 / ts0;
const int64_t ds0 = nb0 / tdst;
const int64_t ds1 = nb1 / tdst;
const int64_t ds2 = nb2 / tdst;
const int64_t ds3 = nb3 / tdst;
/*support both WARP_SIZE or WARP_32_SIZE in code
choose by hardware for better performance
*/
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device);
l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03,
ss0, ss1, ss2, ss3, ds0, ds1, ds2, ds3, eps, stream, ctx.device);
}
+2 -2
View File
@@ -126,7 +126,7 @@ static void soft_max_f32(const float * x,
break;
}
const float val = sycl::native::exp(vals[col] - max_val);
const float val = sycl::native::exp(sycl::max(vals[col] - max_val, -80.0f));
tmp += val;
vals[col] = val;
}
@@ -154,7 +154,7 @@ static void soft_max_f32(const float * x,
tmp = warp_reduce_sum<WARP_SIZE>(tmp);
}
if (sinks) {
tmp += sycl::native::exp(sinks[i02] - max_val);
tmp += sycl::native::exp(sycl::max(sinks[i02] - max_val, -80.0f));
}
const float inv_sum = 1.0f / tmp;
+19 -9
View File
@@ -308,6 +308,7 @@ enum vk_device_architecture {
AMD_RDNA1,
AMD_RDNA2,
AMD_RDNA3,
INTEL_XE1,
INTEL_XE2,
NVIDIA_PRE_TURING,
NVIDIA_TURING,
@@ -365,21 +366,26 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
bool subgroup_size_control = false;
bool integer_dot_product = false;
for (const auto& properties : ext_props) {
if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
subgroup_size_control = true;
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
integer_dot_product = true;
}
}
if (!subgroup_size_control) {
if (!subgroup_size_control || !integer_dot_product) {
return vk_device_architecture::OTHER;
}
vk::PhysicalDeviceProperties2 props2;
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
props2.pNext = &subgroup_size_control_props;
subgroup_size_control_props.pNext = &integer_dot_props;
device.getProperties2(&props2);
if (subgroup_size_control_props.minSubgroupSize == 16) {
@@ -388,6 +394,9 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
// https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
// https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
return vk_device_architecture::INTEL_XE2;
} else if (subgroup_size_control_props.minSubgroupSize == 8 &&
integer_dot_product && integer_dot_props.integerDotProduct4x8BitPackedSignedAccelerated) {
return vk_device_architecture::INTEL_XE1;
}
} else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
@@ -3837,7 +3846,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 };
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support) {
// Xe2/Xe3 with coopmat enabled - warptile performance tuning
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
@@ -4710,7 +4719,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
}
uint32_t rm_iq = 2 * rm_kq;
const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
const bool use_subgroups = device->subgroup_arithmetic;
// Ensure a subgroup size >= 16 is available
const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;
@@ -6361,9 +6370,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
break;
case VK_VENDOR_ID_INTEL: {
// Current Windows driver does not expose BF16 support.
// We only want to use l_warptile if coopmat is available and is Xe2+
const bool xe2_with_coopmat = device->coopmat_support && device->architecture == INTEL_XE2;
const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && xe2_with_coopmat) : xe2_with_coopmat;
// We only want to use l_warptile if coopmat is available
const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && device->coopmat_support) : device->coopmat_support;
device->mul_mat_l[i] = use_l_warptile;
device->mul_mat_id_l[i] = use_l_warptile;
device->mul_mat_m[i] = true;
@@ -17890,9 +17898,9 @@ static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) {
static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
switch (props.vendorID) {
case VK_VENDOR_ID_INTEL:
// Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,
// while some older hardware (ex. Arc A770) has performance regressions
return arch == vk_device_architecture::INTEL_XE2;
// Only allowing Xe2/Xe3 GPU and integrated Xe GPUs at the moment since older hardware (ex. Arc A770) has performance regressions.
return (arch == vk_device_architecture::INTEL_XE2) ||
(arch == vk_device_architecture::INTEL_XE1 && props.deviceType == vk::PhysicalDeviceType::eIntegratedGpu && driver_props.driverID == vk::DriverId::eIntelProprietaryWindows);
case VK_VENDOR_ID_AMD:
if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
// Workaround for AMD proprietary driver reporting support on all GPUs
@@ -17940,6 +17948,8 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev)
case 0xE20B: // B580
case 0xE211: // Pro B60
return 20;
case 0xB080: // PTL Xe3 LPG 2x6 (12 subslices)
return 12;
default:
return 0;
}
@@ -158,7 +158,7 @@ const uint32_t Csh_stride = BS_NPQ;
#ifdef COOPMAT
const uint32_t Csh_len = BS_K * Csh_stride;
#else
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 8; // 8 to workaround compiler bug
#endif
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
#endif
@@ -144,7 +144,7 @@ const uint32_t Csh_stride = BS_NPQ;
#ifdef COOPMAT
const uint32_t Csh_len = BS_K * Csh_stride;
#else
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 8; // 8 to workaround compiler bug
#endif
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
#endif
@@ -28,13 +28,10 @@ vec2 cache_b_ds;
#include "mul_mat_vecq_funcs.glsl"
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint col, const uint b_qs_idx) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
// Preload data_b block
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
const uint b_qs_idx = tid % (32 / K_PER_ITER);
const uint b_block_idx_outer = b_block_idx / 4;
const uint b_block_idx_inner = b_block_idx % 4;
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
@@ -91,35 +88,35 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
}
}
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
const uint col_stride = K_PER_ITER * BLOCK_SIZE;
uint num_iters = p.ncols / col_stride;
if (num_iters * col_stride + K_PER_ITER * tid < p.ncols) {
num_iters++;
}
int unroll_count = 4;
uint unrolled_iters = num_iters & ~(unroll_count - 1);
uint i = 0;
while (i < unrolled_iters) {
const uint b_qs_idx = tid % (32 / K_PER_ITER);
uint col = tid * K_PER_ITER;
while (num_iters >= 4) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
[[unroll]] for (uint k = 0; k < 4; ++k) {
iter(temp, first_row, num_rows, col, b_qs_idx);
col += col_stride;
}
num_iters -= 4;
}
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
while (i < unrolled_iters) {
if (num_iters >= 2) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
}
iter(temp, first_row, num_rows, col, b_qs_idx);
col += col_stride;
iter(temp, first_row, num_rows, col, b_qs_idx);
col += col_stride;
num_iters -= 2;
}
while (i < num_iters) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
if (num_iters > 0) {
iter(temp, first_row, num_rows, col, b_qs_idx);
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
@@ -42,7 +42,7 @@ float op_leaky_relu(float x) {
}
float op_step(float x) {
return x >= 0.0f ? 1.0f : 0.0f;
return x > 0.0f ? 1.0f : 0.0f;
}
float op_tanh(float x) {
+121 -2
View File
@@ -145,6 +145,7 @@ class Keys:
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
FULL_ATTENTION_INTERVAL = "{arch}.full_attention_interval"
HASH_LAYER_COUNT = "{arch}.hash_layer_count"
ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale"
ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
@@ -156,6 +157,7 @@ class Keys:
DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
TARGET_LAYERS = "{arch}.target_layers"
TARGET_HIDDEN_SIZE = "{arch}.target_hidden_size"
BLOCK_SIZE = "{arch}.block_size"
NORM_BEFORE_RESIDUAL = "{arch}.norm_before_residual"
class Attention:
@@ -179,8 +181,12 @@ class Keys:
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
OUTPUT_GROUP_COUNT = "{arch}.attention.output_group_count"
OUTPUT_LORA_RANK = "{arch}.attention.output_lora_rank"
OUTPUT_SCALE = "{arch}.attention.output_scale"
VALUE_SCALE = "{arch}.attention.value_scale"
COMPRESS_RATIOS = "{arch}.attention.compress_ratios"
COMPRESS_ROPE_FREQ_BASE = "{arch}.attention.compress_rope_freq_base"
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
@@ -195,6 +201,11 @@ class Keys:
KEY_LENGTH = "{arch}.attention.indexer.key_length"
TOP_K = "{arch}.attention.indexer.top_k"
class HyperConnection:
COUNT = "{arch}.hyper_connection.count"
SINKHORN_ITERATIONS = "{arch}.hyper_connection.sinkhorn_iterations"
EPSILON = "{arch}.hyper_connection.epsilon"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_COUNT_SWA = "{arch}.rope.dimension_count_swa"
@@ -469,6 +480,7 @@ class MODEL_ARCH(IntEnum):
DEEPSEEK2 = auto()
DEEPSEEK2OCR = auto()
DEEPSEEK32 = auto()
DEEPSEEK4 = auto()
CHATGLM = auto()
GLM4 = auto()
GLM4_MOE = auto()
@@ -517,6 +529,7 @@ class MODEL_ARCH(IntEnum):
PANGU_EMBED = auto()
MISTRAL3 = auto()
EAGLE3 = auto()
DFLASH = auto()
MISTRAL4 = auto()
PADDLEOCR = auto()
MIMO2 = auto()
@@ -553,6 +566,9 @@ class MODEL_TENSOR(IntEnum):
DENSE_2_OUT = auto() # embeddinggemma 2_Dense
DENSE_3_OUT = auto() # embeddinggemma 3_Dense
OUTPUT_NORM = auto()
HC_HEAD_FN = auto()
HC_HEAD_BASE = auto()
HC_HEAD_SCALE = auto()
ROPE_FREQS = auto()
ROPE_FACTORS_LONG = auto()
ROPE_FACTORS_SHORT = auto()
@@ -592,6 +608,7 @@ class MODEL_TENSOR(IntEnum):
FFN_DOWN_CHEXP = auto()
FFN_UP_CHEXP = auto()
FFN_EXP_PROBS_B = auto()
FFN_GATE_TID2EID = auto()
MOE_LATENT_DOWN = auto() # nemotron 3 super
MOE_LATENT_UP = auto() # nemotron 3 super
ATTN_Q_NORM = auto()
@@ -679,6 +696,20 @@ class MODEL_TENSOR(IntEnum):
ATTN_V_B = auto()
ATTN_Q_A_NORM = auto()
ATTN_KV_A_NORM = auto()
ATTN_KV = auto()
ATTN_KV_NORM = auto()
ATTN_OUT_A = auto()
ATTN_OUT_B = auto()
HC_ATTN_FN = auto()
HC_ATTN_BASE = auto()
HC_ATTN_SCALE = auto()
HC_FFN_FN = auto()
HC_FFN_BASE = auto()
HC_FFN_SCALE = auto()
ATTN_COMPRESSOR_WKV = auto()
ATTN_COMPRESSOR_WGATE = auto()
ATTN_COMPRESSOR_APE = auto()
ATTN_COMPRESSOR_NORM = auto()
FFN_SUB_NORM = auto()
ATTN_SUB_NORM = auto()
DEC_ATTN_NORM = auto()
@@ -740,6 +771,10 @@ class MODEL_TENSOR(IntEnum):
INDEXER_PROJ = auto()
INDEXER_ATTN_K = auto()
INDEXER_ATTN_Q_B = auto()
INDEXER_COMPRESSOR_WKV = auto()
INDEXER_COMPRESSOR_WGATE = auto()
INDEXER_COMPRESSOR_APE = auto()
INDEXER_COMPRESSOR_NORM = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
@@ -1025,6 +1060,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.DEEPSEEK2OCR: "deepseek2-ocr",
MODEL_ARCH.DEEPSEEK32: "deepseek32",
MODEL_ARCH.DEEPSEEK4: "deepseek4",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.GLM4: "glm4",
MODEL_ARCH.GLM4_MOE: "glm4moe",
@@ -1074,6 +1110,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MISTRAL3: "mistral3",
MODEL_ARCH.EAGLE3: "eagle3",
MODEL_ARCH.DFLASH: "dflash",
MODEL_ARCH.MISTRAL4: "mistral4",
MODEL_ARCH.PADDLEOCR: "paddleocr",
MODEL_ARCH.MIMO2: "mimo2",
@@ -1108,6 +1145,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.DENSE_2_OUT: "dense_2", # embeddinggemma 2_Dense
MODEL_TENSOR.DENSE_3_OUT: "dense_3", # embeddinggemma 2_Dense
MODEL_TENSOR.HC_HEAD_FN: "output_hc_fn",
MODEL_TENSOR.HC_HEAD_BASE: "output_hc_base",
MODEL_TENSOR.HC_HEAD_SCALE: "output_hc_scale",
MODEL_TENSOR.ROPE_FREQS: "rope_freqs",
MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long",
MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short",
@@ -1149,6 +1189,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.FFN_GATE_TID2EID: "blk.{bid}.ffn_gate_tid2eid",
MODEL_TENSOR.MOE_LATENT_DOWN: "blk.{bid}.ffn_latent_down", # nemotron 3 super
MODEL_TENSOR.MOE_LATENT_UP: "blk.{bid}.ffn_latent_up", # nemotron 3 super
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
@@ -1234,6 +1275,20 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b",
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_KV: "blk.{bid}.attn_kv",
MODEL_TENSOR.ATTN_KV_NORM: "blk.{bid}.attn_kv_a_norm",
MODEL_TENSOR.ATTN_OUT_A: "blk.{bid}.attn_output_a",
MODEL_TENSOR.ATTN_OUT_B: "blk.{bid}.attn_output_b",
MODEL_TENSOR.HC_ATTN_FN: "blk.{bid}.hc_attn_fn",
MODEL_TENSOR.HC_ATTN_BASE: "blk.{bid}.hc_attn_base",
MODEL_TENSOR.HC_ATTN_SCALE: "blk.{bid}.hc_attn_scale",
MODEL_TENSOR.HC_FFN_FN: "blk.{bid}.hc_ffn_fn",
MODEL_TENSOR.HC_FFN_BASE: "blk.{bid}.hc_ffn_base",
MODEL_TENSOR.HC_FFN_SCALE: "blk.{bid}.hc_ffn_scale",
MODEL_TENSOR.ATTN_COMPRESSOR_WKV: "blk.{bid}.attn_compressor_kv",
MODEL_TENSOR.ATTN_COMPRESSOR_WGATE: "blk.{bid}.attn_compressor_gate",
MODEL_TENSOR.ATTN_COMPRESSOR_APE: "blk.{bid}.attn_compressor_ape",
MODEL_TENSOR.ATTN_COMPRESSOR_NORM: "blk.{bid}.attn_compressor_norm",
MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm",
MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm",
MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm",
@@ -1295,6 +1350,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.INDEXER_PROJ: "blk.{bid}.indexer.proj",
MODEL_TENSOR.INDEXER_ATTN_K: "blk.{bid}.indexer.attn_k",
MODEL_TENSOR.INDEXER_ATTN_Q_B: "blk.{bid}.indexer.attn_q_b",
MODEL_TENSOR.INDEXER_COMPRESSOR_WKV: "blk.{bid}.indexer_compressor_kv",
MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE: "blk.{bid}.indexer_compressor_gate",
MODEL_TENSOR.INDEXER_COMPRESSOR_APE: "blk.{bid}.indexer_compressor_ape",
MODEL_TENSOR.INDEXER_COMPRESSOR_NORM: "blk.{bid}.indexer_compressor_norm",
# vision
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
@@ -3135,6 +3194,49 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.DEEPSEEK4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.HC_HEAD_FN,
MODEL_TENSOR.HC_HEAD_BASE,
MODEL_TENSOR.HC_HEAD_SCALE,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_SINKS,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV,
MODEL_TENSOR.ATTN_KV_NORM,
MODEL_TENSOR.ATTN_OUT_A,
MODEL_TENSOR.ATTN_OUT_B,
MODEL_TENSOR.HC_ATTN_FN,
MODEL_TENSOR.HC_ATTN_BASE,
MODEL_TENSOR.HC_ATTN_SCALE,
MODEL_TENSOR.HC_FFN_FN,
MODEL_TENSOR.HC_FFN_BASE,
MODEL_TENSOR.HC_FFN_SCALE,
MODEL_TENSOR.ATTN_COMPRESSOR_WKV,
MODEL_TENSOR.ATTN_COMPRESSOR_WGATE,
MODEL_TENSOR.ATTN_COMPRESSOR_APE,
MODEL_TENSOR.ATTN_COMPRESSOR_NORM,
MODEL_TENSOR.INDEXER_PROJ,
MODEL_TENSOR.INDEXER_ATTN_Q_B,
MODEL_TENSOR.INDEXER_COMPRESSOR_WKV,
MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE,
MODEL_TENSOR.INDEXER_COMPRESSOR_APE,
MODEL_TENSOR.INDEXER_COMPRESSOR_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_TID2EID,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.ERNIE4_5_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -4086,6 +4188,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FC,
MODEL_TENSOR.D2T,
],
MODEL_ARCH.DFLASH: [
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FC,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.MISTRAL4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -4418,8 +4536,9 @@ class GGMLQuantizationType(IntEnum):
class ExpertGatingFuncType(IntEnum):
SOFTMAX = 1
SIGMOID = 2
SOFTMAX = 1
SIGMOID = 2
SQRTSOFTPLUS = 4
# TODO: add GGMLFileType from ggml_ftype in ggml.h
+36
View File
@@ -715,6 +715,9 @@ class GGUFWriter:
def add_full_attention_interval(self, interval: int) -> None:
self.add_uint32(Keys.LLM.FULL_ATTENTION_INTERVAL.format(arch=self.arch), interval)
def add_hash_layer_count(self, count: int) -> None:
self.add_uint32(Keys.LLM.HASH_LAYER_COUNT.format(arch=self.arch), count)
def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
if isinstance(length, int):
self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
@@ -940,6 +943,39 @@ class GGUFWriter:
def add_sliding_window(self, value: int) -> None:
self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)
def add_block_size(self, value: int) -> None:
self.add_uint32(Keys.LLM.BLOCK_SIZE.format(arch=self.arch), value)
def add_target_layers(self, value: Sequence[int]) -> None:
self.add_array(Keys.LLM.TARGET_LAYERS.format(arch=self.arch), value)
def add_target_hidden_size(self, value: int) -> None:
self.add_uint32(Keys.LLM.TARGET_HIDDEN_SIZE.format(arch=self.arch), value)
def add_norm_before_residual(self, value: bool) -> None:
self.add_bool(Keys.LLM.NORM_BEFORE_RESIDUAL.format(arch=self.arch), value)
def add_attention_output_group_count(self, count: int) -> None:
self.add_uint32(Keys.Attention.OUTPUT_GROUP_COUNT.format(arch=self.arch), count)
def add_attention_output_lora_rank(self, length: int) -> None:
self.add_uint32(Keys.Attention.OUTPUT_LORA_RANK.format(arch=self.arch), length)
def add_attention_compress_ratios(self, values: Sequence[int]) -> None:
self.add_array(Keys.Attention.COMPRESS_RATIOS.format(arch=self.arch), values)
def add_attention_compress_rope_freq_base(self, value: float) -> None:
self.add_float32(Keys.Attention.COMPRESS_ROPE_FREQ_BASE.format(arch=self.arch), value)
def add_hyper_connection_count(self, count: int) -> None:
self.add_uint32(Keys.HyperConnection.COUNT.format(arch=self.arch), count)
def add_hyper_connection_sinkhorn_iterations(self, count: int) -> None:
self.add_uint32(Keys.HyperConnection.SINKHORN_ITERATIONS.format(arch=self.arch), count)
def add_hyper_connection_epsilon(self, value: float) -> None:
self.add_float32(Keys.HyperConnection.EPSILON.format(arch=self.arch), value)
def add_attention_scale(self, value: float) -> None:
self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
+5
View File
@@ -1283,6 +1283,11 @@ class TensorNameMap:
MODEL_TENSOR.ENC_OUTPUT_NORM: (
"encoder.final_layer_norm", # t5
"layer_norm", # neobert
"model.hidden_norm", # dflash
),
MODEL_TENSOR.FC: (
"model.fc", # dflash
),
MODEL_TENSOR.CLS: (
@@ -0,0 +1,112 @@
{%- if not add_generation_prompt is defined -%}
{%- set add_generation_prompt = false -%}
{%- endif -%}
{%- if not thinking is defined -%}
{%- if enable_thinking is defined -%}
{%- set thinking = enable_thinking -%}
{%- else -%}
{%- set thinking = false -%}
{%- endif -%}
{%- endif -%}
{%- set dsml_token = 'DSML' -%}
{%- set thinking_start_token = '<think>' -%}
{%- set thinking_end_token = '</think>' -%}
{%- set tools_header = '## Tools\n\nYou have access to a set of tools to help answer the user\'s question. You can invoke tools by writing a "<' + dsml_token + 'tool_calls>" block like the following:\n\n<' + dsml_token + 'tool_calls>\n<' + dsml_token + 'invoke name="$TOOL_NAME">\n<' + dsml_token + 'parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</' + dsml_token + 'parameter>\n...\n</' + dsml_token + 'invoke>\n<' + dsml_token + 'invoke name="$TOOL_NAME2">\n...\n</' + dsml_token + 'invoke>\n</' + dsml_token + 'tool_calls>\n\nString parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.\n\nIf thinking_mode is enabled (triggered by ' + thinking_start_token + '), you MUST output your complete reasoning inside ' + thinking_start_token + '...' + thinking_end_token + ' BEFORE any tool calls or final response.\n\nOtherwise, output directly after ' + thinking_end_token + ' with tool calls or final response.\n\n### Available Tool Schemas\n\n' -%}
{%- set tools_footer = '\nYou MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.\n' -%}
{%- set ns = namespace(system_prompt='', is_first_sp=true) -%}
{%- for message in messages -%}
{%- if message['role'] == 'system' -%}
{%- if ns.is_first_sp -%}
{%- set ns.system_prompt = ns.system_prompt + (message['content'] or '') -%}
{%- set ns.is_first_sp = false -%}
{%- else -%}
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + (message['content'] or '') -%}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if tools is defined and tools -%}
{%- set ts = namespace(schemas='') -%}
{%- for tool in tools -%}
{%- if tool['type'] == 'function' -%}
{%- set ts.schemas = ts.schemas + (tool['function'] | tojson) + '\n' -%}
{%- endif -%}
{%- endfor -%}
{%- if ns.system_prompt -%}
{%- set ns.system_prompt = ns.system_prompt + '\n\n' + tools_header + ts.schemas + tools_footer -%}
{%- else -%}
{%- set ns.system_prompt = tools_header + ts.schemas + tools_footer -%}
{%- endif -%}
{%- endif -%}
{{- bos_token -}}
{{- ns.system_prompt -}}
{%- set last_user_idx = namespace(value=-1) -%}
{%- for message in messages -%}
{%- if message['role'] == 'user' or message['role'] == 'developer' or message['role'] == 'tool' -%}
{%- set last_user_idx.value = loop.index0 -%}
{%- endif -%}
{%- endfor -%}
{%- set state = namespace(in_user=false) -%}
{%- for message in messages -%}
{%- if message['role'] == 'user' or message['role'] == 'developer' -%}
{%- if state.in_user -%}
{{- '\n\n' -}}
{%- else -%}
{{- '<User>' -}}
{%- set state.in_user = true -%}
{%- endif -%}
{{- message['content'] or '' -}}
{%- elif message['role'] == 'tool' -%}
{%- if state.in_user -%}
{{- '\n\n' -}}
{%- else -%}
{{- '<User>' -}}
{%- set state.in_user = true -%}
{%- endif -%}
{{- '<tool_result>' + (message['content'] or '') + '</tool_result>' -}}
{%- elif message['role'] == 'assistant' -%}
{%- set state.in_user = false -%}
{{- '<Assistant>' -}}
{%- set is_after_last_user = loop.index0 > last_user_idx.value -%}
{%- if is_after_last_user and thinking -%}
{{- thinking_start_token -}}
{%- if message['reasoning_content'] is defined and message['reasoning_content'] -%}
{{- message['reasoning_content'] -}}
{%- endif -%}
{{- thinking_end_token -}}
{%- else -%}
{{- thinking_end_token -}}
{%- endif -%}
{%- if message['content'] is defined and message['content'] -%}
{{- message['content'] -}}
{%- endif -%}
{%- if message['tool_calls'] -%}
{{- '\n\n<' + dsml_token + 'tool_calls>\n' -}}
{%- for tool in message['tool_calls'] -%}
{%- set func = tool['function'] -%}
{{- '<' + dsml_token + 'invoke name="' + func['name'] + '">\n' -}}
{%- set args = func['arguments'] -%}
{%- if args is string -%}
{%- set args = args | from_json -%}
{%- endif -%}
{%- for key, val in args.items() -%}
{%- if val is string -%}
{{- '<' + dsml_token + 'parameter name="' + key + '" string="true">' + val + '</' + dsml_token + 'parameter>\n' -}}
{%- else -%}
{{- '<' + dsml_token + 'parameter name="' + key + '" string="false">' + (val | tojson) + '</' + dsml_token + 'parameter>\n' -}}
{%- endif -%}
{%- endfor -%}
{{- '</' + dsml_token + 'invoke>\n' -}}
{%- endfor -%}
{{- '</' + dsml_token + 'tool_calls>' -}}
{%- endif -%}
{{- '<end▁of▁sentence>' -}}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- '<Assistant>' -}}
{%- if thinking -%}
{{- thinking_start_token -}}
{%- else -%}
{{- thinking_end_token -}}
{%- endif -%}
{%- endif -%}
+179
View File
@@ -0,0 +1,179 @@
{{- bos_token }}{%- if tools %}
{%- set tool_definitions %}
{{- "# Tools\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson(ensure_ascii=False) }}
{%- endfor %}
{{- '\n</tools>\n\nTool usage guidelines:\n- You may call zero or more functions. If no function calls are needed, just answer normally and do not include any <function ... </function>.\n- When calling a function, return an XML object within <function ... </function> using:\n<function name="function-name"><param name="param-name">param-value</param></function>\n- param-value may be multi-line. If it contains <, & or newline characters, wrap it in a CDATA block: <param name="param-name"><![CDATA[...multi-line value...]]></param>' }}
{%- endset %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{%- if '<tool_def_sep>' in messages[0].content %}
{{- messages[0].content.replace('<tool_def_sep>', tool_definitions) }}
{%- else %}
{{- messages[0].content + '\n\n' + tool_definitions }}
{%- endif %}
{%- else %}
{{- tool_definitions.lstrip() }}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if message.tool_calls %}
{%- set content_parts = content.split('<tool_sep>') %}
{%- set processed_content = content_parts[0] %}
{%- set tool_calls_count = message.tool_calls|length %}
{%- set tool_sep_count = content_parts|length - 1 %}
{%- set min_count = [tool_calls_count, tool_sep_count]|min %}
{%- for i in range(1, content_parts|length) %}
{%- set tool_index = i - 1 %}
{%- if tool_index < tool_calls_count %}
{%- set tool_call = message.tool_calls[tool_index] %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- set single_tool_xml %}
{{- '<function name="' ~ tool_call.name ~ '">' }}
{%- if tool_call.arguments %}
{%- set args_dict = tool_call.arguments %}
{%- for param_name, param_value in args_dict.items() %}
{{- '<param name="' ~ param_name ~ '">' }}
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
{{- '<![CDATA[' + param_value + ']]>' }}
{%- else %}
{{- param_value }}
{%- endif %}
{{- '</param>' }}
{%- endfor %}
{%- endif %}
{{- '</function>' }}
{%- endset %}
{%- set processed_content = processed_content + single_tool_xml + content_parts[i] %}
{%- else %}
{%- set processed_content = processed_content + content_parts[i] %}
{%- endif %}
{%- endfor %}
{%- if tool_calls_count > tool_sep_count %}
{%- for remaining_index in range(tool_sep_count, tool_calls_count) %}
{%- set tool_call = message.tool_calls[remaining_index] %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{%- set remaining_tool_xml %}
{{- '<function name="' ~ tool_call.name ~ '">' }}
{%- if tool_call.arguments %}
{%- set args_dict = tool_call.arguments %}
{%- for param_name, param_value in args_dict.items() %}
{{- '<param name="' ~ param_name ~ '">' }}
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
{{- '<![CDATA[' + param_value + ']]>' }}
{%- else %}
{{- param_value }}
{%- endif %}
{{- '</param>' }}
{%- endfor %}
{%- endif %}
{{- '</function>' }}
{%- endset %}
{%- set processed_content = processed_content + remaining_tool_xml %}
{%- endfor %}
{%- endif %}
{%- set content = processed_content %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if reasoning_content %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls and not has_tool_sep %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<function name="' ~ tool_call.name ~ '">' }}
{%- if tool_call.arguments %}
{%- set args_dict = tool_call.arguments %}
{%- for param_name, param_value in args_dict.items() %}
{{- '<param name="' ~ param_name ~ '">' }}
{%- if param_value is string and ('<' in param_value or '&' in param_value or '\n' in param_value) %}
{{- '<![CDATA[' + param_value + ']]>' }}
{%- else %}
{{- param_value }}
{%- endif %}
{{- '</param>' }}
{%- endfor %}
{%- endif %}
{{- '</function>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{%- if message.content is string %}
{{- content }}
{%- else %}
{{- message.content | tojson(ensure_ascii=False) }}
{%- endif %}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined %}
{%- if enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- elif enable_thinking is true %}
{{- '<think>\n' }}
{%- endif %}
{%- endif %}
{%- endif %}
+1 -1
View File
@@ -1 +1 @@
707321c4cf6d21cb4bc831aa8b687dbf01a521ce
eced84c86f8b012c752c016f7fe789adea168e1e
+1
View File
@@ -25,6 +25,7 @@ add_library(llama
llama-kv-cache.cpp
llama-kv-cache-iswa.cpp
llama-kv-cache-dsa.cpp
llama-kv-cache-dsv4.cpp
llama-memory.cpp
llama-memory-hybrid.cpp
llama-memory-hybrid-iswa.cpp
+57
View File
@@ -77,6 +77,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" },
{ LLM_ARCH_DEEPSEEK32, "deepseek32" },
{ LLM_ARCH_DEEPSEEK4, "deepseek4" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_GLM4, "glm4" },
{ LLM_ARCH_GLM4_MOE, "glm4moe" },
@@ -129,6 +130,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_EAGLE3, "eagle3" },
{ LLM_ARCH_DFLASH, "dflash" },
{ LLM_ARCH_MISTRAL4, "mistral4" },
{ LLM_ARCH_PADDLEOCR, "paddleocr" },
{ LLM_ARCH_MIMO2, "mimo2" },
@@ -249,9 +251,19 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" },
{ LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" },
{ LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" },
{ LLM_KV_ATTENTION_OUTPUT_GROUP_COUNT, "%s.attention.output_group_count" },
{ LLM_KV_ATTENTION_OUTPUT_LORA_RANK, "%s.attention.output_lora_rank" },
{ LLM_KV_ATTENTION_COMPRESS_ROPE_FREQ_BASE, "%s.attention.compress_rope_freq_base" },
{ LLM_KV_ATTENTION_COMPRESS_RATIOS, "%s.attention.compress_ratios" },
{ LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" },
{ LLM_KV_ATTENTION_RECURRENT_LAYERS, "%s.attention.recurrent_layers" },
{ LLM_KV_HYPER_CONNECTION_COUNT, "%s.hyper_connection.count" },
{ LLM_KV_HYPER_CONNECTION_SINKHORN_ITERATIONS, "%s.hyper_connection.sinkhorn_iterations" },
{ LLM_KV_HYPER_CONNECTION_EPSILON, "%s.hyper_connection.epsilon" },
{ LLM_KV_HASH_LAYER_COUNT, "%s.hash_layer_count" },
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -439,6 +451,23 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_KV, "blk.%d.attn_kv" },
{ LLM_TENSOR_ATTN_KV_NORM, "blk.%d.attn_kv_a_norm" },
{ LLM_TENSOR_ATTN_OUT_A, "blk.%d.attn_output_a" },
{ LLM_TENSOR_ATTN_OUT_B, "blk.%d.attn_output_b" },
{ LLM_TENSOR_HC_HEAD_FN, "output_hc_fn" },
{ LLM_TENSOR_HC_HEAD_BASE, "output_hc_base" },
{ LLM_TENSOR_HC_HEAD_SCALE, "output_hc_scale" },
{ LLM_TENSOR_HC_ATTN_FN, "blk.%d.hc_attn_fn" },
{ LLM_TENSOR_HC_ATTN_BASE, "blk.%d.hc_attn_base" },
{ LLM_TENSOR_HC_ATTN_SCALE, "blk.%d.hc_attn_scale" },
{ LLM_TENSOR_HC_FFN_FN, "blk.%d.hc_ffn_fn" },
{ LLM_TENSOR_HC_FFN_BASE, "blk.%d.hc_ffn_base" },
{ LLM_TENSOR_HC_FFN_SCALE, "blk.%d.hc_ffn_scale" },
{ LLM_TENSOR_ATTN_COMPRESSOR_WKV, "blk.%d.attn_compressor_kv" },
{ LLM_TENSOR_ATTN_COMPRESSOR_WGATE, "blk.%d.attn_compressor_gate" },
{ LLM_TENSOR_ATTN_COMPRESSOR_APE, "blk.%d.attn_compressor_ape" },
{ LLM_TENSOR_ATTN_COMPRESSOR_NORM, "blk.%d.attn_compressor_norm" },
{ LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
{ LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
{ LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
@@ -565,6 +594,11 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" },
{ LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" },
{ LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_WKV, "blk.%d.indexer_compressor_kv" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, "blk.%d.indexer_compressor_gate" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_APE, "blk.%d.indexer_compressor_ape" },
{ LLM_TENSOR_INDEXER_COMPRESSOR_NORM, "blk.%d.indexer_compressor_norm" },
{ LLM_TENSOR_FFN_GATE_TID2EID, "blk.%d.ffn_gate_tid2eid" },
{ LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" },
{ LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" },
{ LLM_TENSOR_FC, "fc" },
@@ -615,6 +649,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_KV_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_OUT_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_HEAD_FN, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_HEAD_BASE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_ADD}},
{LLM_TENSOR_HC_HEAD_SCALE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_HC_ATTN_FN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_ATTN_BASE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_HC_ATTN_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_HC_FFN_FN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_HC_FFN_BASE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_HC_FFN_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_COMPRESSOR_WKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_COMPRESSOR_WGATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_COMPRESSOR_APE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_ATTN_COMPRESSOR_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}},
@@ -778,6 +829,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_COMPRESSOR_WKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_COMPRESSOR_WGATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_COMPRESSOR_APE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
{LLM_TENSOR_INDEXER_COMPRESSOR_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_FFN_GATE_TID2EID, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
{LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
@@ -932,6 +988,7 @@ bool llm_arch_supports_sm_tensor(const llm_arch & arch) {
case LLM_ARCH_OLMOE:
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK32:
case LLM_ARCH_DEEPSEEK4:
case LLM_ARCH_GLM_DSA:
case LLM_ARCH_BITNET:
case LLM_ARCH_T5:
+34
View File
@@ -82,6 +82,7 @@ enum llm_arch {
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_DEEPSEEK2OCR,
LLM_ARCH_DEEPSEEK32,
LLM_ARCH_DEEPSEEK4,
LLM_ARCH_CHATGLM,
LLM_ARCH_GLM4,
LLM_ARCH_GLM4_MOE,
@@ -143,6 +144,7 @@ enum llm_arch {
LLM_ARCH_TALKIE,
LLM_ARCH_MELLUM,
LLM_ARCH_EAGLE3,
LLM_ARCH_DFLASH,
LLM_ARCH_UNKNOWN,
};
@@ -254,9 +256,19 @@ enum llm_kv {
LLM_KV_ATTENTION_INDEXER_HEAD_COUNT,
LLM_KV_ATTENTION_INDEXER_KEY_LENGTH,
LLM_KV_ATTENTION_INDEXER_TOP_K,
LLM_KV_ATTENTION_OUTPUT_GROUP_COUNT,
LLM_KV_ATTENTION_OUTPUT_LORA_RANK,
LLM_KV_ATTENTION_COMPRESS_ROPE_FREQ_BASE,
LLM_KV_ATTENTION_COMPRESS_RATIOS,
LLM_KV_ATTENTION_SHARED_KV_LAYERS,
LLM_KV_ATTENTION_RECURRENT_LAYERS,
LLM_KV_HYPER_CONNECTION_COUNT,
LLM_KV_HYPER_CONNECTION_SINKHORN_ITERATIONS,
LLM_KV_HYPER_CONNECTION_EPSILON,
LLM_KV_HASH_LAYER_COUNT,
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_COUNT_SWA,
LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -500,10 +512,27 @@ enum llm_tensor {
LLM_TENSOR_ATTN_Q_B,
LLM_TENSOR_ATTN_KV_A_MQA,
LLM_TENSOR_ATTN_KV_B,
LLM_TENSOR_ATTN_KV,
LLM_TENSOR_ATTN_KV_NORM,
LLM_TENSOR_ATTN_OUT_A,
LLM_TENSOR_ATTN_OUT_B,
LLM_TENSOR_ATTN_K_B,
LLM_TENSOR_ATTN_V_B,
LLM_TENSOR_ATTN_Q_A_NORM,
LLM_TENSOR_ATTN_KV_A_NORM,
LLM_TENSOR_HC_HEAD_FN,
LLM_TENSOR_HC_HEAD_BASE,
LLM_TENSOR_HC_HEAD_SCALE,
LLM_TENSOR_HC_ATTN_FN,
LLM_TENSOR_HC_ATTN_BASE,
LLM_TENSOR_HC_ATTN_SCALE,
LLM_TENSOR_HC_FFN_FN,
LLM_TENSOR_HC_FFN_BASE,
LLM_TENSOR_HC_FFN_SCALE,
LLM_TENSOR_ATTN_COMPRESSOR_WKV,
LLM_TENSOR_ATTN_COMPRESSOR_WGATE,
LLM_TENSOR_ATTN_COMPRESSOR_APE,
LLM_TENSOR_ATTN_COMPRESSOR_NORM,
LLM_TENSOR_ATTN_SUB_NORM,
LLM_TENSOR_FFN_SUB_NORM,
LLM_TENSOR_DEC_ATTN_NORM,
@@ -565,6 +594,11 @@ enum llm_tensor {
LLM_TENSOR_INDEXER_PROJ,
LLM_TENSOR_INDEXER_ATTN_K,
LLM_TENSOR_INDEXER_ATTN_Q_B,
LLM_TENSOR_INDEXER_COMPRESSOR_WKV,
LLM_TENSOR_INDEXER_COMPRESSOR_WGATE,
LLM_TENSOR_INDEXER_COMPRESSOR_APE,
LLM_TENSOR_INDEXER_COMPRESSOR_NORM,
LLM_TENSOR_FFN_GATE_TID2EID,
LLM_TENSOR_NEXTN_PROJ_PRE,
LLM_TENSOR_NEXTN_PROJ_POST,
LLM_TENSOR_NEXTN_EH_PROJ,
+8 -4
View File
@@ -100,10 +100,10 @@ llama_context::llama_context(
cparams.ctx_other = params.ctx_other;
}
if (model.arch == LLM_ARCH_EAGLE3) {
if (model.arch == LLM_ARCH_EAGLE3 || model.arch == LLM_ARCH_DFLASH) {
if (model.tok_embd == nullptr || model.output == nullptr) {
if (params.ctx_other == nullptr) {
throw std::runtime_error("EAGLE3 requires ctx_other to be set (this warning is normal during memory fitting)");
throw std::runtime_error(model.arch_name() + " requires ctx_other to be set (this warning is normal during memory fitting)");
}
cparams.ctx_other = params.ctx_other;
}
@@ -256,7 +256,7 @@ llama_context::llama_context(
LLAMA_LOG_INFO("%s: n_outputs_max = %u\n", __func__, cparams.n_outputs_max);
if (cparams.n_ctx_seq < hparams.n_ctx_train) {
LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
LLAMA_LOG_INFO("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n",
__func__, cparams.n_ctx_seq, hparams.n_ctx_train);
}
@@ -2321,7 +2321,11 @@ void llama_context::output_reorder() {
//
uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const {
if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) {
if (model.arch == LLM_ARCH_QWEN3NEXT ||
model.arch == LLM_ARCH_KIMI_LINEAR ||
model.arch == LLM_ARCH_QWEN35 ||
model.arch == LLM_ARCH_QWEN35MOE ||
model.arch == LLM_ARCH_DEEPSEEK4) {
return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors());
}
uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors());
+358 -24
View File
@@ -8,6 +8,7 @@
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-kv-cache-dsa.h"
#include "llama-kv-cache-dsv4.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
@@ -17,6 +18,7 @@
#include <cstring>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_set>
// dedup helpers
@@ -486,7 +488,11 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
// the mask is left unallocated when the graph only stores K/V without attending
// (e.g. DFlash's KV-injection pass)
if (self_kq_mask && self_kq_mask->buffer) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
if (self_k_rot) {
mctx->set_input_k_rot(self_k_rot);
@@ -564,7 +570,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
// base tensors may not be allocated if there are no non-SWA attention layers
if (self_k_idxs && self_k_idxs->buffer) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
if (self_v_idxs) {
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
}
}
// the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live
@@ -575,7 +583,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
// swa tensors may not be allocated if there are no SWA attention layers
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
if (self_v_idxs_swa) {
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
}
}
if (self_kq_mask_swa && self_kq_mask_swa->buffer) {
@@ -629,6 +639,283 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
return res;
}
static void dsv4_set_i64(ggml_tensor * dst, const std::vector<int64_t> & src) {
if (!dst || !dst->buffer) {
return;
}
GGML_ASSERT(dst->ne[0] == (int64_t) src.size());
ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst));
}
static void dsv4_set_i32(ggml_tensor * dst, const std::vector<int32_t> & src) {
if (!dst || !dst->buffer) {
return;
}
GGML_ASSERT(dst->ne[0] == (int64_t) src.size());
ggml_backend_tensor_set(dst, src.data(), 0, src.size()*ggml_element_size(dst));
}
static void dsv4_set_kq_mask(
ggml_tensor * dst,
const llama_kv_cache_dsv4_context::comp_plan & plan,
uint32_t n_tokens,
int64_t n_stream) {
if (!dst || !dst->buffer) {
return;
}
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(n_stream > 0);
GGML_ASSERT(n_tokens%n_stream == 0);
GGML_ASSERT(dst->ne[0] == plan.n_kv);
GGML_ASSERT(dst->ne[1] == (int64_t) n_tokens/n_stream);
GGML_ASSERT(dst->ne[2] == 1);
GGML_ASSERT(dst->ne[3] == n_stream);
GGML_ASSERT((int64_t) plan.n_visible.size() == (int64_t) n_tokens);
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
float * data = (float *) dst->data;
for (int64_t i = 0; i < (int64_t) n_tokens; ++i) {
const int32_t n_visible = plan.n_visible[i];
for (int64_t j = 0; j < dst->ne[0]; ++j) {
data[i*dst->ne[0] + j] = j < n_visible ? 0.0f : -INFINITY;
}
}
}
static ggml_tensor * dsv4_build_raw_kq_mask(
ggml_context * ctx,
const llama_kv_cache_dsv4_raw_context * mctx,
const llama_ubatch & ubatch,
const llama_cparams & cparams,
int64_t n_stream) {
const auto n_kv = mctx->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
GGML_ASSERT(n_stream > 0);
GGML_ASSERT(n_tokens%n_stream == 0);
const bool use_fattn = cparams.flash_attn && (!cparams.kv_unified || n_stream == 1);
const auto type = use_fattn ? GGML_TYPE_F16 : GGML_TYPE_F32;
ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_kq_mask");
return res;
}
static bool dsv4_can_reuse_raw_kq_mask(
ggml_tensor * kq_mask,
const llama_kv_cache_dsv4_raw_context * mctx,
const llama_ubatch & ubatch,
int64_t n_stream) {
const auto n_kv = mctx->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
GGML_ASSERT(n_stream > 0);
bool res = true;
res &= (kq_mask->ne[0] == n_kv);
res &= (kq_mask->ne[1] == n_tokens/n_stream);
res &= (kq_mask->ne[2] == 1);
res &= (kq_mask->ne[3] == n_stream);
return res;
}
static std::string dsv4_plan_positions(const std::vector<int32_t> & values) {
std::ostringstream ss;
ss << "[";
for (size_t i = 0; i < values.size(); ++i) {
if (i > 0) {
ss << ", ";
}
ss << values[i];
}
ss << "]";
return ss.str();
}
static bool dsv4_compress_debug() {
static const bool debug = []() {
const char * env = getenv("LLAMA_DSV4_COMPRESS_DEBUG");
return env && atoi(env) > 0;
}();
return debug;
}
static void dsv4_set_comp_inputs(
const llm_graph_input_dsv4::comp_input & inp,
const llama_kv_cache_dsv4_context::comp_plan & plan,
const char * name,
bool debug,
uint32_t n_tokens,
int64_t n_stream) {
dsv4_set_i32(inp.state_pos, plan.state_pos);
dsv4_set_i32(inp.state_persist_src_idxs, plan.state_persist_src_idxs);
dsv4_set_i32(inp.state_persist_dst_idxs, plan.state_persist_dst_idxs);
dsv4_set_i32(inp.state_read_idxs, plan.state_read_idxs);
dsv4_set_i64(inp.state_write_idxs, plan.state_write_idxs);
dsv4_set_i32(inp.state_write_pos, plan.state_write_pos);
dsv4_set_kq_mask(inp.kq_mask, plan, n_tokens, n_stream);
if (debug || dsv4_compress_debug()) {
LLAMA_LOG_INFO("%s: %s n_tokens=%u, n_stream=%d, state_persist_dst=%s, state_write_pos=%s\n",
__func__, name, n_tokens, (int) n_stream,
dsv4_plan_positions(plan.state_persist_dst_idxs).c_str(),
dsv4_plan_positions(plan.state_write_pos).c_str());
}
}
static bool dsv4_can_reuse_tensor_1d(ggml_tensor * t, int64_t ne0) {
return (t == nullptr && ne0 == 0) || (t != nullptr && t->ne[0] == ne0);
}
static bool dsv4_can_reuse_kq_mask(
ggml_tensor * t,
const llama_kv_cache_dsv4_context::comp_plan & plan,
uint32_t n_tokens,
int64_t n_stream) {
if (plan.n_kv == 0) {
return t == nullptr;
}
GGML_ASSERT(n_stream > 0);
return t != nullptr &&
t->ne[0] == plan.n_kv &&
t->ne[1] == (int64_t) n_tokens/n_stream &&
t->ne[2] == 1 &&
t->ne[3] == n_stream;
}
static bool dsv4_can_reuse_comp_input(
const llm_graph_input_dsv4::comp_input & inp,
const llama_kv_cache_dsv4_context::comp_plan & plan,
uint32_t n_tokens,
int64_t n_stream) {
bool res = true;
res &= dsv4_can_reuse_tensor_1d(inp.state_pos, plan.state_pos.size());
res &= dsv4_can_reuse_tensor_1d(inp.state_persist_src_idxs, plan.state_persist_src_idxs.size());
res &= dsv4_can_reuse_tensor_1d(inp.state_persist_dst_idxs, plan.state_persist_dst_idxs.size());
res &= dsv4_can_reuse_tensor_1d(inp.state_read_idxs, plan.state_read_idxs.size());
res &= dsv4_can_reuse_tensor_1d(inp.state_write_idxs, plan.state_write_idxs.size());
res &= dsv4_can_reuse_tensor_1d(inp.state_write_pos, plan.state_write_pos.size());
res &= dsv4_can_reuse_kq_mask(inp.kq_mask, plan, n_tokens, n_stream);
return res;
}
static ggml_tensor * dsv4_build_input_1d(
ggml_context * ctx,
ggml_type type,
int64_t ne0,
const std::string & name) {
if (ne0 == 0) {
return nullptr;
}
ggml_tensor * res = ggml_new_tensor_1d(ctx, type, ne0);
ggml_set_input(res);
ggml_set_name(res, name.c_str());
return res;
}
static void dsv4_build_comp_inputs(
ggml_context * ctx,
llm_graph_input_dsv4::comp_input & inp,
const llama_kv_cache_dsv4_context::comp_plan & plan,
const char * name,
int64_t n_stream) {
inp.state_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_pos.size(), std::string("dsv4_") + name + "_state_pos");
inp.state_persist_src_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_persist_src_idxs.size(), std::string("dsv4_") + name + "_state_persist_src_idxs");
inp.state_persist_dst_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_persist_dst_idxs.size(), std::string("dsv4_") + name + "_state_persist_dst_idxs");
inp.state_read_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_read_idxs.size(), std::string("dsv4_") + name + "_state_read_idxs");
inp.state_write_idxs = dsv4_build_input_1d(ctx, GGML_TYPE_I64, plan.state_write_idxs.size(), std::string("dsv4_") + name + "_state_write_idxs");
inp.state_write_pos = dsv4_build_input_1d(ctx, GGML_TYPE_I32, plan.state_write_pos.size(), std::string("dsv4_") + name + "_state_write_pos");
if (plan.n_kv > 0) {
const int64_t n_tokens = (int64_t) plan.n_visible.size();
GGML_ASSERT(n_stream > 0);
GGML_ASSERT(n_tokens%n_stream == 0);
inp.kq_mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, plan.n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp.kq_mask);
ggml_set_name(inp.kq_mask, (std::string("dsv4_") + name + "_kq_mask").c_str());
}
}
void llm_graph_input_dsv4_raw::set_input(const llama_ubatch * ubatch) {
if (self_k_idxs && self_k_idxs->buffer) {
mctx->set_input_k_idxs(self_k_idxs);
}
if (self_kq_mask && self_kq_mask->buffer) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
if (self_k_rot) {
mctx->set_input_k_rot(self_k_rot);
}
}
void llm_graph_input_dsv4::set_input(const llama_ubatch * ubatch) {
const auto & plan_csa = mctx->get_csa_plan(*ubatch);
const auto & plan_hca = mctx->get_hca_plan(*ubatch);
const auto & plan_lid = mctx->get_lid_plan(*ubatch);
const int64_t n_stream = plan_csa.n_stream;
inp_raw->mctx = mctx->get_raw();
inp_raw->set_input(ubatch);
dsv4_set_comp_inputs(inp_csa, plan_csa, "csa", debug > 0, ubatch->n_tokens, n_stream);
dsv4_set_comp_inputs(inp_hca, plan_hca, "hca", debug > 0, ubatch->n_tokens, n_stream);
dsv4_set_comp_inputs(inp_lid, plan_lid, "lid", debug > 0, ubatch->n_tokens, n_stream);
if (inp_lid.k_rot && inp_lid.k_rot->buffer) {
mctx->get_lid()->set_input_k_rot(inp_lid.k_rot);
}
}
bool llm_graph_input_dsv4::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_dsv4_context *>(params.mctx);
this->mctx = mctx;
inp_raw->mctx = mctx->get_raw();
bool res = true;
const auto & plan_csa = mctx->get_csa_plan(params.ubatch);
const auto & plan_hca = mctx->get_hca_plan(params.ubatch);
const auto & plan_lid = mctx->get_lid_plan(params.ubatch);
const int64_t n_stream = plan_csa.n_stream;
const auto * raw_ctx = mctx->get_raw();
inp_raw->mctx = raw_ctx;
if (inp_raw->self_k_idxs && inp_raw->self_k_idxs->buffer) {
res &= inp_raw->self_k_idxs->ne[0] == raw_ctx->get_n_write();
}
if (inp_raw->self_kq_mask && inp_raw->self_kq_mask->buffer) {
res &= dsv4_can_reuse_raw_kq_mask(inp_raw->self_kq_mask, raw_ctx, params.ubatch, n_stream);
}
res &= dsv4_can_reuse_comp_input(inp_csa, plan_csa, params.ubatch.n_tokens, n_stream);
res &= dsv4_can_reuse_comp_input(inp_hca, plan_hca, params.ubatch.n_tokens, n_stream);
res &= dsv4_can_reuse_comp_input(inp_lid, plan_lid, params.ubatch.n_tokens, n_stream);
return res;
}
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
GGML_ASSERT(cross_kq_mask);
@@ -904,6 +1191,7 @@ void llm_graph_result::reset() {
t_logits = nullptr;
t_embd = nullptr;
t_embd_pooled = nullptr;
t_h_nextn = nullptr;
t_layer_inp.resize(LLAMA_MAX_LAYERS);
std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr);
@@ -1346,20 +1634,24 @@ ggml_tensor * llm_graph_context::build_ffn(
switch (type_op) {
case LLM_FFN_SILU:
if (gate && type_gate == LLM_FFN_PAR) {
// Step35: HF clamps gate (after SiLU) and up before multiplication
if (arch == LLM_ARCH_STEP35 && il >= 0) {
if (il >= 0) {
const float limit = hparams.swiglu_clamp_shexp[il];
constexpr float eps = 1e-6f;
if (limit > eps) {
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
cb(gate_act, "ffn_silu", il);
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
cb(gate_act, "ffn_silu_clamped", il);
tmp = ggml_clamp(ctx0, tmp, -limit, limit);
cb(tmp, "ffn_up_clamped", il);
cur = ggml_mul(ctx0, gate_act, tmp);
if (arch == LLM_ARCH_DEEPSEEK4) {
cur = ggml_clamp(ctx0, cur, -INFINITY, limit);
cb(cur, "ffn_gate_clamped", il);
cur = ggml_swiglu_split(ctx0, cur, tmp);
} else {
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
cb(gate_act, "ffn_silu", il);
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
cb(gate_act, "ffn_silu_clamped", il);
cur = ggml_mul(ctx0, gate_act, tmp);
}
cb(cur, "ffn_swiglu_limited", il);
type_gate = LLM_FFN_SEQ;
break;
@@ -1469,7 +1761,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * gate_up_exps,
ggml_tensor * up_exps_s,
ggml_tensor * gate_exps_s,
ggml_tensor * down_exps_s) const {
ggml_tensor * down_exps_s,
ggml_tensor * selected_experts_in) const {
return build_moe_ffn(
cur,
gate_inp, /* gate_inp_b */ nullptr,
@@ -1489,7 +1782,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
/* gate_up_exps_b */ nullptr,
up_exps_s,
gate_exps_s,
down_exps_s
down_exps_s,
selected_experts_in
);
}
@@ -1516,7 +1810,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * gate_up_exps_b,
ggml_tensor * up_exps_s,
ggml_tensor * gate_exps_s,
ggml_tensor * down_exps_s) const {
ggml_tensor * down_exps_s,
ggml_tensor * selected_experts_in) const {
const int64_t n_embd = cur->ne[0];
const int64_t n_tokens = cur->ne[1];
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -1525,6 +1820,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
if (probs_in == nullptr) {
logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS) {
ggml_mul_mat_set_prec(logits, GGML_PREC_F32);
}
cb(logits, "ffn_moe_logits", il);
} else {
logits = probs_in;
@@ -1549,6 +1847,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
{
probs = logits; // [n_expert, n_tokens]
} break;
case LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS:
{
probs = ggml_sqrt(ctx0, ggml_softplus(ctx0, logits)); // [n_expert, n_tokens]
} break;
default:
GGML_ABORT("fatal error");
}
@@ -1599,8 +1901,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
}
// select experts
ggml_tensor * selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
ggml_tensor * selected_experts = selected_experts_in;
if (selected_experts == nullptr) {
selected_experts = ggml_argsort_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
}
cb(selected_experts, "ffn_moe_topk", il);
if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
@@ -1713,20 +2018,24 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
switch (type_op) {
case LLM_FFN_SILU:
if (gate_exps) {
// Step35: per-layer clamp for routed experts
if (arch == LLM_ARCH_STEP35 && il >= 0) {
if (il >= 0) {
const float limit = hparams.swiglu_clamp_exp[il];
constexpr float eps = 1e-6f;
if (limit > eps) {
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
cb(gate_act, "ffn_moe_silu", il);
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
cb(gate_act, "ffn_moe_silu_clamped", il);
up = ggml_clamp(ctx0, up, -limit, limit);
cb(up, "ffn_moe_up_clamped", il);
cur = ggml_mul(ctx0, gate_act, up);
if (arch == LLM_ARCH_DEEPSEEK4) {
cur = ggml_clamp(ctx0, cur, -INFINITY, limit);
cb(cur, "ffn_moe_gate_clamped", il);
cur = ggml_swiglu_split(ctx0, cur, up);
} else {
ggml_tensor * gate_act = ggml_silu(ctx0, cur);
cb(gate_act, "ffn_moe_silu", il);
gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit);
cb(gate_act, "ffn_moe_silu_clamped", il);
cur = ggml_mul(ctx0, gate_act, up);
}
cb(cur, "ffn_moe_swiglu_limited", il);
break;
}
@@ -2755,6 +3064,31 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}
llm_graph_input_dsv4 * llm_graph_context::build_inp_dsv4() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_dsv4_context *>(mctx);
const auto * raw_ctx = mctx_cur->get_raw();
auto inp_raw = std::make_unique<llm_graph_input_dsv4_raw>(cparams, raw_ctx);
const int64_t n_stream = mctx_cur->get_csa_plan(ubatch).n_stream;
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "DSV4 expects SWA raw cache");
inp_raw->self_k_idxs = raw_ctx->build_input_k_idxs(ctx0, ubatch);
inp_raw->self_kq_mask = dsv4_build_raw_kq_mask(ctx0, raw_ctx, ubatch, cparams, n_stream);
inp_raw->self_kq_mask_cnv = inp_raw->self_kq_mask;
inp_raw->self_k_rot = raw_ctx->build_input_k_rot(ctx0);
auto inp = std::make_unique<llm_graph_input_dsv4>(cparams, std::move(inp_raw), mctx_cur);
dsv4_build_comp_inputs(ctx0, inp->inp_csa, mctx_cur->get_csa_plan(ubatch), "csa", n_stream);
dsv4_build_comp_inputs(ctx0, inp->inp_hca, mctx_cur->get_hca_plan(ubatch), "hca", n_stream);
dsv4_build_comp_inputs(ctx0, inp->inp_lid, mctx_cur->get_lid_plan(ubatch), "lid", n_stream);
inp->inp_lid.k_rot = mctx_cur->get_lid()->build_input_k_rot(ctx0);
return (llm_graph_input_dsv4 *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_rs(
ggml_tensor * s,
ggml_tensor * state_copy_main,
+81 -2
View File
@@ -23,6 +23,8 @@ struct llama_memory_context_i;
class llama_kv_cache_context;
class llama_kv_cache_dsa_context;
class llama_kv_cache_dsv4_raw_context;
class llama_kv_cache_dsv4_context;
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
@@ -459,6 +461,79 @@ public:
const llama_kv_cache_iswa_context * mctx;
};
// DSV4 raw graph inputs are SWA-only, but their mask may be stream-shaped
// so raw K can be concatenated with DSV4 compressed K in one attention op.
class llm_graph_input_dsv4_raw {
public:
llm_graph_input_dsv4_raw(
const llama_cparams & cparams,
const llama_kv_cache_dsv4_raw_context * mctx) :
cparams(cparams),
mctx(mctx) {
}
void set_input(const llama_ubatch * ubatch);
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_k_rot = nullptr;
const llama_cparams cparams;
const llama_kv_cache_dsv4_raw_context * mctx;
};
class llm_graph_input_dsv4 : public llm_graph_input_i {
public:
struct comp_input {
ggml_tensor * state_pos = nullptr; // I32 [n_state]
ggml_tensor * state_persist_src_idxs = nullptr; // I32 [n_state_persist]
ggml_tensor * state_persist_dst_idxs = nullptr; // I32 [n_state_persist]
ggml_tensor * state_read_idxs = nullptr; // I32 [ratio*n_state_write]
ggml_tensor * state_write_idxs = nullptr; // I64 [n_state_write]
ggml_tensor * state_write_pos = nullptr; // I32 [n_state_write]
ggml_tensor * kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * k_rot = nullptr;
};
llm_graph_input_dsv4(
const llama_cparams & cparams,
std::unique_ptr<llm_graph_input_dsv4_raw> inp_raw,
const llama_kv_cache_dsv4_context * mctx) :
inp_raw(std::move(inp_raw)),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_dsv4() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
llm_graph_input_dsv4_raw * get_raw() const { return inp_raw.get(); }
const comp_input & get_csa() const { return inp_csa; }
const comp_input & get_hca() const { return inp_hca; }
const comp_input & get_lid() const { return inp_lid; }
std::unique_ptr<llm_graph_input_dsv4_raw> inp_raw;
comp_input inp_csa;
comp_input inp_hca;
comp_input inp_lid;
const llama_cparams cparams;
const llama_kv_cache_dsv4_context * mctx;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
public:
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -920,7 +995,8 @@ struct llm_graph_context {
ggml_tensor * gate_up_exps = nullptr,
ggml_tensor * up_exps_s = nullptr,
ggml_tensor * gate_exps_s = nullptr,
ggml_tensor * down_exps_s = nullptr) const;
ggml_tensor * down_exps_s = nullptr,
ggml_tensor * selected_experts_in = nullptr) const;
ggml_tensor * build_moe_ffn(
ggml_tensor * cur,
@@ -945,7 +1021,8 @@ struct llm_graph_context {
ggml_tensor * gate_up_exps_b = nullptr,
ggml_tensor * up_exps_s = nullptr,
ggml_tensor * gate_exps_s = nullptr,
ggml_tensor * down_exps_s = nullptr) const;
ggml_tensor * down_exps_s = nullptr,
ggml_tensor * selected_experts_in = nullptr) const;
//
// inputs
@@ -1045,6 +1122,8 @@ struct llm_graph_context {
llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
llm_graph_input_dsv4 * build_inp_dsv4() const;
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
ggml_tensor * build_attn(
llm_graph_input_attn_kv_iswa * inp,
+11
View File
@@ -14,6 +14,7 @@ enum llama_expert_gating_func_type {
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3, // applied to the router weights instead of the logits
LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS = 4,
};
enum llama_swa_type {
@@ -226,6 +227,16 @@ struct llama_hparams {
uint32_t indexer_head_size = 0;
uint32_t indexer_top_k = 0;
// DeepSeek-V4
uint32_t dsv4_o_group_count = 0;
uint32_t dsv4_o_lora_rank = 0;
uint32_t dsv4_hc_mult = 0;
uint32_t dsv4_hc_sinkhorn_iters = 0;
uint32_t dsv4_hash_layer_count = 0;
float dsv4_compress_rope_base = 0.0f;
float dsv4_hc_eps = 0.0f;
std::array<uint32_t, LLAMA_MAX_LAYERS> dsv4_compress_ratios;
// qwen3vl deepstack
// When parsed from GGUF, this implies the first N layers consume the first
// N deepstack embeddings. Use deepstack_mapping_arr if you need a more
File diff suppressed because it is too large Load Diff
+362
View File
@@ -0,0 +1,362 @@
#pragma once
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include <map>
#include <memory>
#include <unordered_map>
#include <vector>
class llama_dsv4_comp_state {
public:
llama_dsv4_comp_state(
const llama_model & model,
bool offload,
bool unified,
uint32_t n_seq_max,
uint32_t ratio,
uint32_t state_size,
uint32_t n_embd_state,
const char * name,
const llama_memory_i::layer_filter_cb & filter);
void clear(bool data);
uint32_t get_ratio() const;
uint32_t get_state_size() const;
uint32_t get_n_stream() const;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const;
void state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags);
ggml_tensor * get_kv (ggml_context * ctx, int32_t il) const;
ggml_tensor * get_score(ggml_context * ctx, int32_t il) const;
ggml_tensor * cpy_kv (ggml_context * ctx, ggml_tensor * cur, ggml_tensor * idxs, int32_t il) const;
ggml_tensor * cpy_score(ggml_context * ctx, ggml_tensor * cur, ggml_tensor * idxs, int32_t il) const;
private:
struct layer {
uint32_t il;
ggml_tensor * kv;
ggml_tensor * score;
};
const uint32_t ratio;
const uint32_t state_size;
const uint32_t n_embd_state;
const uint32_t n_stream;
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
std::vector<layer> layers;
std::unordered_map<int32_t, int32_t> map_layer_ids;
size_t total_size() const;
};
//
// llama_kv_cache_dsv4
//
// DSV4 uses a normal raw/SWA token cache plus compressed K-only block caches.
// The compressed caches are storage only; DSV4-specific visibility and block
// planning are handled by llama_kv_cache_dsv4_context / llm_graph_input_dsv4.
class llama_kv_cache_dsv4 : public llama_memory_i {
public:
llama_kv_cache_dsv4(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_kv_cache_dsv4() = default;
//
// llama_memory_i
//
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_context_ptr init_full() override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_kv_cache_dsv4 specific API
//
llama_kv_cache_iswa * get_raw() const;
llama_kv_cache * get_csa() const;
llama_kv_cache * get_hca() const;
llama_kv_cache * get_lid() const;
llama_dsv4_comp_state * get_csa_state() const;
llama_dsv4_comp_state * get_hca_state() const;
llama_dsv4_comp_state * get_lid_state() const;
private:
llama_hparams hparams_raw;
llama_hparams hparams_csa;
llama_hparams hparams_hca;
llama_hparams hparams_lid;
const uint32_t n_seq_max;
std::unique_ptr<llama_kv_cache_iswa> kv_raw;
std::unique_ptr<llama_kv_cache> kv_csa;
std::unique_ptr<llama_kv_cache> kv_hca;
std::unique_ptr<llama_kv_cache> kv_lid;
std::unique_ptr<llama_dsv4_comp_state> csa_state;
std::unique_ptr<llama_dsv4_comp_state> hca_state;
std::unique_ptr<llama_dsv4_comp_state> lid_state;
void clear_compressed(bool data);
};
// DSV4 raw attention only uses the SWA half of kv_raw. The base half is kept
// for generic ISWA bookkeeping, but it has no DSV4 layers to expose here.
class llama_kv_cache_dsv4_raw_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
llama_kv_cache_dsv4_raw_context(llama_kv_cache_iswa * kv);
llama_kv_cache_dsv4_raw_context(
llama_kv_cache_iswa * kv,
llama_context * lctx,
bool optimize);
llama_kv_cache_dsv4_raw_context(
llama_kv_cache_iswa * kv,
slot_info_vec_t sinfos_base_write,
slot_info_vec_t sinfos_swa_write,
slot_info_vec_t sinfos_swa_read,
std::vector<llama_ubatch> ubatches,
std::vector<llama_ubatch> ubatches_write);
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
uint32_t get_n_kv() const;
uint32_t get_n_write() const;
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
void set_input_k_idxs(ggml_tensor * dst) const;
void set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_k_rot(ggml_tensor * dst) const;
private:
size_t i_next = 0;
llama_kv_cache * kv_swa = nullptr;
slot_info_vec_t sinfos_write;
slot_info_vec_t sinfos_read;
std::vector<llama_ubatch> ubatches;
std::vector<llama_ubatch> ubatches_write;
const llama_memory_context_ptr ctx_base_mem;
const llama_memory_context_ptr ctx_swa_mem;
uint32_t n_kv = 0;
const llama_memory_status status;
};
// DSV4 compressed KV rows are graph outputs, not normal token KV writes.
// Keep a small context that exposes K tensors without generic apply() semantics.
class llama_kv_cache_dsv4_comp_context {
public:
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
llama_kv_cache_dsv4_comp_context(llama_kv_cache * kv);
llama_kv_cache_dsv4_comp_context(
llama_kv_cache * kv,
slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches);
bool next();
uint32_t get_n_kv() const;
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
void set_input_k_rot(ggml_tensor * dst) const;
private:
llama_kv_cache * kv;
size_t i_cur = 0;
slot_info_vec_t sinfos;
std::vector<llama_ubatch> ubatches;
uint32_t n_kv;
};
class llama_kv_cache_dsv4_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
struct comp_plan {
// Per-ubatch recipe for updating compressor state, committing completed
// compressed rows, and masking the compressed attention source.
// APE row ids, i.e. pos % ratio, for the compressor-state updates.
std::vector<int32_t> state_pos;
// Current-ubatch source row ids and unique persistent-state
// destination row ids for deterministic ring-state updates.
std::vector<int32_t> state_persist_src_idxs;
std::vector<int32_t> state_persist_dst_idxs;
// Flattened source row ids used for state-backed commits. Source rows
// index the graph-local [persistent_state | current_ubatch_scratch]
// tensor. For overlapped compression the first half is previous rows
// and the second half is current rows; a final synthetic zero/-inf row
// may be addressed for the first block's previous half.
std::vector<int32_t> state_read_idxs;
// Final compressed-cache row ids written by state-backed commits.
// A non-boundary CSA/LID decode step can target a masked scratch row.
std::vector<int64_t> state_write_idxs;
// RoPE positions for state-backed commits.
std::vector<int32_t> state_write_pos;
// Number of completed compressed rows visible for each query token.
std::vector<int32_t> n_visible;
// Number of streams used by the attention graph for this ubatch.
int64_t n_stream = 1;
// Graph-width for compressed rows. This can be larger than n_visible
// so masked padding rows do not force a new graph at every CSA block.
int64_t n_kv = 0;
};
llama_kv_cache_dsv4_context(llama_memory_status status);
llama_kv_cache_dsv4_context(
llama_kv_cache_dsv4 * kv);
llama_kv_cache_dsv4_context(
llama_kv_cache_dsv4 * kv,
llama_context * lctx,
bool optimize);
llama_kv_cache_dsv4_context(
llama_kv_cache_dsv4 * kv,
slot_info_vec_t sinfos_raw_base_write,
slot_info_vec_t sinfos_raw_swa_write,
slot_info_vec_t sinfos_raw_swa_read,
std::vector<llama_ubatch> ubatches,
std::vector<llama_ubatch> ubatches_raw);
virtual ~llama_kv_cache_dsv4_context();
//
// llama_memory_context_i
//
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_dsv4_context specific API
//
const llama_kv_cache_dsv4_raw_context * get_raw() const;
const llama_kv_cache_dsv4_comp_context * get_csa() const;
const llama_kv_cache_dsv4_comp_context * get_hca() const;
const llama_kv_cache_dsv4_comp_context * get_lid() const;
const llama_dsv4_comp_state * get_csa_state() const;
const llama_dsv4_comp_state * get_hca_state() const;
const llama_dsv4_comp_state * get_lid_state() const;
const comp_plan & get_csa_plan() const;
const comp_plan & get_hca_plan() const;
const comp_plan & get_lid_plan() const;
const comp_plan & get_csa_plan(const llama_ubatch & ubatch) const;
const comp_plan & get_hca_plan(const llama_ubatch & ubatch) const;
const comp_plan & get_lid_plan(const llama_ubatch & ubatch) const;
private:
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
std::vector<comp_plan> plans_csa;
std::vector<comp_plan> plans_hca;
std::vector<comp_plan> plans_lid;
const std::unique_ptr<llama_kv_cache_dsv4_raw_context> ctx_raw;
const llama_memory_context_ptr ctx_csa_mem;
const llama_memory_context_ptr ctx_hca_mem;
const llama_memory_context_ptr ctx_lid_mem;
const std::unique_ptr<llama_kv_cache_dsv4_comp_context> ctx_csa;
const std::unique_ptr<llama_kv_cache_dsv4_comp_context> ctx_hca;
const std::unique_ptr<llama_kv_cache_dsv4_comp_context> ctx_lid;
const llama_dsv4_comp_state * csa_state = nullptr;
const llama_dsv4_comp_state * hca_state = nullptr;
const llama_dsv4_comp_state * lid_state = nullptr;
bool reserve_plans = false;
mutable comp_plan reserve_plan_csa;
mutable comp_plan reserve_plan_hca;
mutable comp_plan reserve_plan_lid;
const llama_memory_status status;
};
+22 -1
View File
@@ -26,7 +26,28 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse,
const layer_share_cb & share) : hparams(model.hparams), unified(unified) {
const layer_share_cb & share) :
llama_kv_cache_iswa(model, model.hparams, type_k, type_v, v_trans, offload, swa_full, unified,
kv_size, n_seq_max, n_ubatch, n_pad, mem_other, filter, reuse, share) {
}
llama_kv_cache_iswa::llama_kv_cache_iswa(
const llama_model & model,
const llama_hparams & hparams,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse,
const layer_share_cb & share) : unified(unified) {
// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
+18 -2
View File
@@ -30,6 +30,24 @@ public:
const layer_reuse_cb & reuse,
const layer_share_cb & share);
llama_kv_cache_iswa(
const llama_model & model,
const llama_hparams & hparams,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool swa_full,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad,
llama_memory_t mem_other,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse,
const layer_share_cb & share);
~llama_kv_cache_iswa() = default;
//
@@ -73,8 +91,6 @@ public:
llama_kv_cache * get_swa () const;
private:
const llama_hparams & hparams;
const bool unified;
std::unique_ptr<llama_kv_cache> kv_base;
+26 -6
View File
@@ -211,10 +211,12 @@ llama_kv_cache::llama_kv_cache(
n_embd_head_k_all = -1;
}
if (n_embd_head_v_all == 0) {
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
n_embd_head_v_all = -1;
if (!is_mla) {
if (n_embd_head_v_all == 0) {
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
n_embd_head_v_all = -1;
}
}
// [TAG_V_CACHE_VARIABLE]
@@ -336,8 +338,9 @@ llama_kv_cache::llama_kv_cache(
ggml_is_quantized(type_k) &&
hparams.n_embd_head_k() % 64 == 0;
// always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer
if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) {
// always create Hadamard rotation tensors for DeepSeek lightning indexers
if ((model.arch == LLM_ARCH_DEEPSEEK32 || model.arch == LLM_ARCH_DEEPSEEK4) &&
hparams.n_embd_head_k_full == hparams.indexer_head_size) {
attn_rot_k = true;
}
@@ -1220,6 +1223,23 @@ ggml_type llama_kv_cache::type_v() const {
return layers[0].v->type;
}
std::vector<uint32_t> llama_kv_cache::get_layer_ids() const {
std::vector<uint32_t> res;
res.reserve(layers.size());
for (const auto & layer : layers) {
res.push_back(layer.il);
}
return res;
}
ggml_tensor * llama_kv_cache::get_k_storage(int32_t il) const {
const int32_t ikv = map_layer_ids.at(il);
return layers[ikv].k;
}
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;
+3
View File
@@ -161,6 +161,9 @@ public:
ggml_type type_k() const;
ggml_type type_v() const;
std::vector<uint32_t> get_layer_ids() const;
ggml_tensor * get_k_storage(int32_t il) const;
//
// graph_build API
//
+3
View File
@@ -294,6 +294,8 @@ namespace GGUFMeta {
}
template bool llama_model_loader::get_arr_n(enum llm_kv kid, uint32_t & result, bool required);
template std::enable_if<std::is_integral<uint32_t>::value, bool>::type
llama_model_loader::get_arr_n<uint32_t>(const std::string & key, uint32_t & result, bool required);
template<typename T>
bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) {
@@ -395,6 +397,7 @@ namespace GGUFMeta {
template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required);
template bool llama_model_loader::get_arr<std::array<int32_t, 512>>(enum llm_kv kid, std::array<int32_t, 512> & result, bool required);
template bool llama_model_loader::get_arr<std::vector<int32_t>>(enum llm_kv kid, std::vector<int32_t> & result, bool required);
template bool llama_model_loader::get_arr<std::array<uint32_t, LLAMA_MAX_LAYERS>>(enum llm_kv kid, std::array<uint32_t, LLAMA_MAX_LAYERS> & result, bool required);
template<typename T>
bool llama_model_loader::get_key(const std::string & key, T & result, bool required) {
+33 -2
View File
@@ -11,6 +11,7 @@
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-kv-cache-dsa.h"
#include "llama-kv-cache-dsv4.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
@@ -181,6 +182,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_deepseek2ocr(params);
case LLM_ARCH_DEEPSEEK32:
return new llama_model_deepseek32(params);
case LLM_ARCH_DEEPSEEK4:
return new llama_model_deepseek4(params);
case LLM_ARCH_GLM_DSA:
return new llama_model_glm_dsa(params);
case LLM_ARCH_MISTRAL4:
@@ -291,6 +294,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_mistral3(params);
case LLM_ARCH_EAGLE3:
return new llama_model_eagle3(params);
case LLM_ARCH_DFLASH:
return new llama_model_dflash(params);
case LLM_ARCH_MIMO2:
return new llama_model_mimo2(params);
case LLM_ARCH_KIMI_LINEAR:
@@ -815,6 +820,7 @@ static const char * llama_expert_gating_func_name(llama_expert_gating_func_type
switch (type) {
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX: return "softmax";
case LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID: return "sigmoid";
case LLAMA_EXPERT_GATING_FUNC_TYPE_SQRT_SOFTPLUS: return "sqrtsoftplus";
default: return "unknown";
}
}
@@ -2154,7 +2160,24 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
}
}
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
if (arch == LLM_ARCH_DEEPSEEK4) {
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE);
res = new llama_kv_cache_dsv4(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
params.swa_full,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
cparams.n_ubatch,
1,
filter,
reuse);
} else if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
if (arch == LLM_ARCH_GEMMA4_ASSISTANT) {
@@ -2326,6 +2349,11 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
}
int32_t llama_model_n_swa(const llama_model * model) {
// dsv4 kv-cache has SWA but it cannot be used as a rollback because of
// other compression ratios, so we return 0 here
if (model->arch == LLM_ARCH_DEEPSEEK4) {
return 0;
}
return model->hparams.n_swa;
}
@@ -2407,6 +2435,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_DEEPSEEK2OCR:
case LLM_ARCH_DEEPSEEK32:
case LLM_ARCH_DEEPSEEK4:
case LLM_ARCH_PLM:
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:
@@ -2494,6 +2523,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_STEP35:
case LLM_ARCH_TALKIE:
case LLM_ARCH_MELLUM:
case LLM_ARCH_DFLASH:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
@@ -2617,7 +2647,8 @@ bool llama_model_has_encoder(const llama_model * model) {
switch (model->arch) {
case LLM_ARCH_T5:
case LLM_ARCH_T5ENCODER:
case LLM_ARCH_EAGLE3: return true;
case LLM_ARCH_EAGLE3:
case LLM_ARCH_DFLASH: return true;
default: return false;
}
}
+25
View File
@@ -255,9 +255,11 @@ struct llama_layer {
struct ggml_tensor * wq_b = nullptr;
struct ggml_tensor * wkv_a_mqa = nullptr;
struct ggml_tensor * wkv_b = nullptr;
struct ggml_tensor * wkv = nullptr;
struct ggml_tensor * wk_b = nullptr;
struct ggml_tensor * wv_b = nullptr;
struct ggml_tensor * wqkv_b = nullptr;
struct ggml_tensor * wo_a = nullptr;
struct ggml_tensor * wo_b = nullptr;
struct ggml_tensor * wq_cross = nullptr;
struct ggml_tensor * wk_cross = nullptr;
@@ -333,6 +335,7 @@ struct llama_layer {
struct ggml_tensor * ffn_up_b = nullptr; // b3
struct ggml_tensor * ffn_act = nullptr;
struct ggml_tensor * ffn_exp_probs_b = nullptr;
struct ggml_tensor * ffn_gate_tid2eid = nullptr;
// mamba proj
struct ggml_tensor * ssm_in = nullptr;
@@ -463,6 +466,23 @@ struct llama_layer {
// openai-moe
struct ggml_tensor * attn_sinks = nullptr;
// DeepSeek-V4
struct ggml_tensor * attn_kv_norm = nullptr;
struct ggml_tensor * hc_attn_fn = nullptr;
struct ggml_tensor * hc_attn_base = nullptr;
struct ggml_tensor * hc_attn_scale = nullptr;
struct ggml_tensor * hc_ffn_fn = nullptr;
struct ggml_tensor * hc_ffn_base = nullptr;
struct ggml_tensor * hc_ffn_scale = nullptr;
struct ggml_tensor * attn_comp_wkv = nullptr;
struct ggml_tensor * attn_comp_wgate = nullptr;
struct ggml_tensor * attn_comp_ape = nullptr;
struct ggml_tensor * attn_comp_norm = nullptr;
struct ggml_tensor * indexer_comp_wkv = nullptr;
struct ggml_tensor * indexer_comp_wgate = nullptr;
struct ggml_tensor * indexer_comp_ape = nullptr;
struct ggml_tensor * indexer_comp_norm = nullptr;
// cogvlm
struct ggml_tensor * visexp_attn_wqkv = nullptr;
struct ggml_tensor * visexp_attn_wo = nullptr;
@@ -553,6 +573,11 @@ struct llama_model {
struct ggml_tensor * nextn_proj_pre = nullptr;
struct ggml_tensor * nextn_proj_post = nullptr;
// DeepSeek-V4
struct ggml_tensor * hc_head_fn = nullptr;
struct ggml_tensor * hc_head_base = nullptr;
struct ggml_tensor * hc_head_scale = nullptr;
// classifier
struct ggml_tensor * cls = nullptr;
struct ggml_tensor * cls_b = nullptr;
File diff suppressed because it is too large Load Diff
+276
View File
@@ -0,0 +1,276 @@
#include "models.h"
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
void llama_model_dflash::load_arch_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
if (!ml.get_arr(LLM_KV_TARGET_LAYERS, target_layer_ids, false)) {
throw std::runtime_error("DFlash model requires 'target_layers' in GGUF metadata");
}
hparams.n_embd_inp_enc_impl = (uint32_t) target_layer_ids.size() * hparams.n_embd;
LLAMA_LOG_INFO("%s: DFlash extract_layers = [", __func__);
for (size_t i = 0; i < target_layer_ids.size(); ++i) {
LLAMA_LOG_INFO("%d%s", target_layer_ids[i], i + 1 < target_layer_ids.size() ? ", " : "");
}
LLAMA_LOG_INFO("]\n");
// optional interleaved sliding-window attention with per-layer pattern array.
// DFlash has a single rope, so the SWA rope == main rope.
if (ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false) && hparams.n_swa > 0) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer());
hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train;
hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
}
type = LLM_TYPE_UNKNOWN;
}
void llama_model_dflash::load_arch_tensors(llama_model_loader &) {
LLAMA_LOAD_LOCALS;
const int64_t n_embd_inp = hparams.n_embd_inp_enc();
fc = create_tensor(tn(LLM_TENSOR_FC, "weight"), { n_embd_inp, n_embd }, 0);
output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), { n_embd }, 0); // encoder hidden_norm (after fc)
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); // decoder final norm
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
}
}
std::unique_ptr<llm_graph_context> llama_model_dflash::build_arch_graph(const llm_graph_params & params) const {
switch (params.gtype) {
case LLM_GRAPH_TYPE_ENCODER:
return std::make_unique<graph<true>>(*this, params);
case LLM_GRAPH_TYPE_DEFAULT:
case LLM_GRAPH_TYPE_DECODER:
return std::make_unique<graph<false>>(*this, params);
default:
GGML_ABORT("invalid graph type");
};
}
template <>
ggml_tensor * llama_model_dflash::graph<true>::build_inp_embd_enc() const {
auto inp_target = std::make_unique<llm_graph_input_embd>(hparams.n_embd_inp_enc());
inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp_enc(), n_tokens);
ggml_set_input(inp_target->embd);
ggml_tensor * cur = inp_target->embd;
cb(cur, "inp_embd", -1);
res->add_input(std::move(inp_target));
return cur;
}
// DFlash Encoder: processes target model features through feature fusion layer
template <>
llama_model_dflash::graph<true>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
ggml_tensor * cur = build_inp_embd_enc();
cur = build_lora_mm(model.fc, cur);
cb(cur, "fc_out", -1);
cur = build_norm(cur, model.output_norm_enc, NULL, LLM_NORM_RMS, -1);
cb(cur, "enc_norm_out", -1);
ggml_set_output(cur);
res->t_h_nextn = cur;
ggml_build_forward_expand(gf, cur);
}
// DFlash decoder, dual-mode by batch type:
// * embd batch -> fused target features: project + inject K/V into the cache.
// * token batch -> noise-block diffusion: attend over [committed, MASK...] to generate draft tokens
template <>
llama_model_dflash::graph<false>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k());
ggml_tensor * inp_pos = build_inp_pos();
// optional iSWA: pick the matching attention input
const bool use_iswa = hparams.swa_type != LLAMA_SWA_TYPE_NONE;
llm_graph_input_attn_kv * inp_attn = nullptr;
llm_graph_input_attn_kv_iswa * inp_attn_iswa = nullptr;
if (use_iswa) {
inp_attn_iswa = build_attn_inp_kv_iswa();
} else {
inp_attn = build_attn_inp_kv();
}
const float kq_scale = 1.0f/sqrtf(float(n_embd_head));
// KV cache injection
if (ubatch.embd) {
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
ggml_set_input(inp->embd);
ggml_tensor * inp_g = inp->embd;
cb(inp_g, "inp_g_embeddings", -1);
res->add_input(std::move(inp));
for (int il = 0; il < n_layer; ++il) {
const auto & layer = model.layers[il];
ggml_tensor * Kcur = build_lora_mm(layer.wk, inp_g);
ggml_tensor * Vcur = build_lora_mm(layer.wv, inp_g);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Kcur = build_norm(Kcur, layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur_injected", il);
cb(Vcur, "Vcur_injected", il);
if (use_iswa) {
// route each layer's K/V to its sub-cache: SWA layers -> sliding cache, full -> dense
const bool is_swa = hparams.is_swa(il);
const auto * kv = is_swa ? inp_attn_iswa->mctx->get_swa() : inp_attn_iswa->mctx->get_base();
ggml_tensor * k_idxs = is_swa ? inp_attn_iswa->get_k_idxs_swa() : inp_attn_iswa->get_k_idxs();
ggml_tensor * v_idxs = is_swa ? inp_attn_iswa->get_v_idxs_swa() : inp_attn_iswa->get_v_idxs();
ggml_build_forward_expand(gf, kv->cpy_k(ctx0, Kcur, k_idxs, il));
ggml_build_forward_expand(gf, kv->cpy_v(ctx0, Vcur, v_idxs, il));
} else {
ggml_build_forward_expand(gf, inp_attn->mctx->cpy_k(ctx0, Kcur, inp_attn->get_k_idxs(), il));
ggml_build_forward_expand(gf, inp_attn->mctx->cpy_v(ctx0, Vcur, inp_attn->get_v_idxs(), il));
}
}
res->t_embd = inp_g;
ggml_build_forward_expand(gf, inp_g);
return;
}
// tok_embd from the target model (shared via ctx_other)
auto * tok_embd = model.tok_embd;
if (tok_embd == nullptr) {
GGML_ASSERT(cparams.ctx_other != nullptr);
const auto * model_other = llama_get_model(cparams.ctx_other);
GGML_ASSERT(model_other->tok_embd != nullptr && "DFlash decoder requires the target model's token embeddings");
tok_embd = model_other->tok_embd;
}
auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
ggml_set_input(inp->tokens);
ggml_tensor * inpL = ggml_get_rows(ctx0, tok_embd, inp->tokens);
cb(inpL, "inp_noise_embd", -1);
res->add_input(std::move(inp));
for (int il = 0; il < n_layer; ++il) {
const auto & layer = model.layers[il];
ggml_tensor * noise_norm = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
cb(noise_norm, "noise_norm", il);
ggml_tensor * Qcur = build_lora_mm(layer.wq, noise_norm);
ggml_tensor * Kcur = build_lora_mm(layer.wk, noise_norm);
ggml_tensor * Vcur = build_lora_mm(layer.wv, noise_norm);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = build_norm(Qcur, layer.attn_q_norm, NULL, LLM_NORM_RMS, il);
Kcur = build_norm(Kcur, layer.attn_k_norm, NULL, LLM_NORM_RMS, il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
// cache-aware, non-causal attention
ggml_tensor * cur = use_iswa
? build_attn(inp_attn_iswa, layer.wo, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il)
: build_attn(inp_attn, layer.wo, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
layer.ffn_up, NULL, NULL,
layer.ffn_gate, NULL, NULL,
layer.ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "l_out", il);
inpL = cur;
}
ggml_tensor * cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head from the target model (shared via ctx_other)
auto * output = model.output;
if (output == nullptr) {
GGML_ASSERT(cparams.ctx_other != nullptr);
const auto * model_other = llama_get_model(cparams.ctx_other);
GGML_ASSERT(model_other->output != nullptr && "DFlash decoder requires the target model's output projection");
output = model_other->output;
}
cur = build_lora_mm(output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
-1
View File
@@ -169,7 +169,6 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp,
GGML_ASSERT(ubatch.equal_seqs());
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
GGML_ASSERT(d_inner % n_head == 0);
GGML_ASSERT(d_inner % d_state == 0);
GGML_ASSERT(d_inner % n_group == 0);
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+7 -6
View File
@@ -39,10 +39,11 @@ void llama_model_mamba2::load_arch_tensors(llama_model_loader &) {
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t n_group = hparams.ssm_n_group;
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head;
const int64_t dt_rank = hparams.ssm_dt_rank;
const int64_t conv_dim = d_inner + 2 * n_group * d_state;
const int64_t d_in_proj = d_inner + conv_dim + dt_rank;
// only an expansion factor of 2 is supported for now
GGML_ASSERT(2 * n_embd == d_inner);
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -68,11 +69,11 @@ void llama_model_mamba2::load_arch_tensors(llama_model_loader &) {
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0);
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0);
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {dt_rank}, 0);
// no "weight" suffix for these
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0);
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, dt_rank}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, dt_rank}, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);

Some files were not shown because too many files have changed in this diff Show More