Compare commits

..

76 Commits

Author SHA1 Message Date
Hongqiang Wang ebd048fc5e opencl: flash attention improvement (#25069)
* opencl: rework FA kernel for f16 and f32

* opencl: flash-attention prefill prepass kernels

- flash_attn_kv_pad_f16    pads the tail KV tile to a BLOCK_N multiple
- flash_attn_mask_pad_f16  pads the matching mask tile
- flash_attn_blk_f16       classifies each KV tile per query block as
                           fully masked / mixed / fully unmasked, so
                           the main kernel can skip fully-masked tiles
                           and the mask lookup for fully-unmasked ones

* opencl: FA kernels for q4_0 and q8_0

* opencl: `set_rows` for f32 to q8_0/q4_0

* opencl: dequant kernels for q4_0 and q8_0

* opencl: add FA tile tuning table with override

* opencl: wire host side for FA

* opencl: q4_0 MoE tensors are also SOA'ed

* opencl: cosmetic fix

* opencl: refactor, also clarify some code paths in comments

* opencl: fix inifity for `-cl-finite-math-only`

---------

Co-authored-by: Li He <lih@qti.qualcomm.com>
2026-06-27 15:36:06 -07:00
Gaurav Garg 0ed235ea2c [CUDA] Added a cudaMemcpy2DAsync fast path to ggml_cuda_cpy (#25057)
* [CUDA] Added a cudaMemcpy2DAsync fast path to ggml_cuda_cpy

Add a CUDA ggml_cpy fast path for same-type, same-shape strided copies that are just 2D pitched block copies.
When tensors are not fully contiguous but each row is contiguous, it now uses cudaMemcpy2DAsync instead of the slow element-wise scalar copy kernel.

This fixes the GDN recurrent snapshot update with -np 4, where rollback slots are separated by cache stride gaps.

* Add new tests that execute the new optimized strided copy path

* Return unsupported for strided copy in OpenVINO, as new tests are failing
2026-06-27 17:46:21 +05:30
Neo Zhang 9bebfcb4bc sycl : fix failed ut cases of norm (#25044) 2026-06-27 12:13:43 +03:00
Ruben Ortlam 0b6529d818 vulkan: fix step operator for 0 input (#25036) 2026-06-27 10:57:31 +02:00
Christian Kastner c299a92c38 binaries : Improve rpc-server and export-graph-ops names. (#25045)
Tests are generally prefixed with -test, so rename export-graph-ops
accordingly.

rpc-server is probably too generic a name for /usr/bin. Because it
should work with any ggml application, it is renamed to ggml-rpc-server.
2026-06-27 10:31:29 +03:00
Sigbjørn Skjæret 0275c0f800 ci : add windows-openvino to check-release (#25022) 2026-06-27 10:30:56 +03:00
Sigbjørn Skjæret 83d385b429 tests : fix test-chat-template --no-common option (#25075) 2026-06-27 10:30:19 +03:00
Adrien Gallouët 050ee92d04 app : allow --version, --licenses & --help (#25054)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-26 23:18:11 +02:00
Andreas Kieslinger 3fc4e10527 sched : reintroduce less synchronizations during split compute (#20793)
* CUDA:  Improve performance via less synchronizations between token (#17795)

* Adds CPU-to-CUDA copy capability to
ggml_backend_cuda_cpy_tensor_async()

* Adds function to relax sync requirements between input copies on
supported backends (CUDA for now)

* Exchanges synchronous copy with async copy function.

* Adds macro guards to allow compilation in non-CUDA builds

* Reworked backend detection in ggml-backend.cpp to avoid linking
conflicts

* Relax requirement of checks in async CUDA copies from backend and buffer type to just buffer type, to avoid linking issues

* Minor cleanup

* Makes opt-in to relax use of explicit syncs more general. Backends like
vulkan which require a synchronization between HtoD copies and graph
execution could also adopt this change now.

* Reintroduces stricter check for CPU->CUDA backend async copy via
GGML_DEVICE_TYPE_CPU.

* Corrects initialization of ggml_backend_sync_mode in
ggml_backend_sched_split initialization

* Simplifies synchronizations to adhere to `saaasg` pattern.

* Apply suggestion from @ggerganov (src->buffer to buf_src)

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Apply suggestion from @ggerganov (src->buffer to buf_src) v2

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Apply suggestions from @johannesgaessler code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Adds single-GPU synchronizations to multi-GPU settings to fix hip backend pipeline parallel bugs.

* Scheduler Hardening: Exclude hip/MUSA from copy_from_host CPU split ->
GPU split optimization

* Scheduler Hardening: Re-adding original additional synchronizations for
non-async backends

* Adds disclaimer to hip/musa exclusion of copy_from_host. Highlights that it is out of
precaution, but that no perf-impact is visible, and that it can be
revisited separately anytime.

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2026-06-26 17:18:30 +03:00
Adrien Gallouët 5d8ccdf9d1 devops : add llama in all docker images (#25035)
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-26 15:15:48 +02:00
Xuan-Son Nguyen 024930c6ad arg: fix handling --spec-draft-hf and --hf-repo-v (#25043)
* arg: fix handling --spec-draft-hf and --hf-repo-v

* fix missing mparams.hf_file
2026-06-26 14:36:03 +02:00
Ravi Panchumarthy 5397c36194 openvino: Update to OV 2026.2.1, self-contained release packages, operator improvements (#24974)
* Update to OV 2026.2.1, Make OV release packages self-contained

* Update to OV 2026.2.1, Make OV release packages self-contained

* OpenVINO Backend: Remove compute_op_type hardcoded sets (#222)

* OpenVINO Backend: Remove compute_op_type hardcoded sets

* revert get_op_type removal

* OpenVINO backend: enable softmax with sink input

* OpenVINO backend: opt mul_mat_id convert process for large size

* OpenVINO backend: Modify add_id to support 2D/4D

* OpenVINO Backend: Add glu_swiglu_oai

* PR review: fix paths

* PR review: fix path consistency

---------

Co-authored-by: Mostafa <mostafas.main.email@gmail.com>
Co-authored-by: Xuejun <Xuejun.Zhai@intel.com>
2026-06-26 15:07:19 +03:00
Georgi Gerganov e7ea94afcb sync : ggml 2026-06-26 15:04:42 +03:00
Georgi Gerganov 96183e9820 ggml : bump version to 0.15.3 (ggml/1550) 2026-06-26 15:04:42 +03:00
nullname 487a6cc164 vulkan: opt mul_mat_vecq for mi50 (#22933) 2026-06-26 13:49:24 +02:00
Jiang, Fish 5a6a0dd7e1 vulkan: add INTEL_XE1 arch enum and enable coopmat1 on Intel Xe-LPG Plus (#24404)
* vulkan: add INTEL_PRE_XE2 arch enum and enable coopmat1 on Intel Xe-LPG Plus (1/3, Xe1-ARLH)

Co-authored-by: Xia, Jie <jie.xia@intel.com>
Co-authored-by: Liu, Russell <russell.liu@intel.com>

* Address comments of bf16 and trailing whitespace

* Rename INTEL_PRE_XE2 to INTEL_XE1 and remove driver workaround

* Add Windows driver check

---------

Co-authored-by: Xia, Jie <jie.xia@intel.com>
Co-authored-by: Liu, Russell <russell.liu@intel.com>
2026-06-26 13:26:22 +02:00
Sanjay Ahari ded1561b42 ui: fix accessibility for hover-gated interactive elements assisted by claude(in debugging and tests) (#24727) 2026-06-26 12:55:38 +02:00
Jeff Bolz 9df06805ee vulkan: Workaround compiler bug in conv2d coopmat2 path (#24924)
* vulkan: Workaround compiler bug in conv2d coopmat2 path

* apply same workaround to CONV_3D

* Apply suggestion from @jeffbolznv
2026-06-26 11:53:32 +02:00
leonardHONG 2f18fe13c5 CUDA: add cublasSgemmBatched mapping for HIP/MUSA vendor headers (#25033) 2026-06-26 11:42:56 +02:00
Tarek Dakhran c16c35b814 ggml-cpu: fix SVE leftover path in ggml_vec_dot_f32 (#24699)
* ggml-cpu: fix SVE leftover path in ggml_vec_dot_f32

2D convolutions with kernel size 9 produced different results on SVE
enabled ARM devices. After debugging it turned out that ggml_vec_dot_f32
was using data from inactive lanes.

Use svmla_f32_m(pg, sum1, ax1, ay1) so inactive lanes retain sum1.

* cont : clean-up

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2026-06-26 10:41:56 +03:00
Pascal 1a87dcdc45 server + ui: SSE Replay Buffer (#23226)
* server: SSE replay buffer, survives client disconnect

Opt in on POST /v1/chat/completions when the client sends
X-Stream-Resume: 1 and a non empty X-Conversation-Id. The conv id is
the session identity end to end, no extra opaque token. The drain
runs detached server side and buffers SSE bytes, the generation
survives HTTP disconnect, F5, or lets users switch from iOS Safari
to another app without losing the actively generated response.

Routes:
  GET    /v1/stream/<conv_id>?from=N       replay
  GET    /v1/streams[?conversation_id=X]   list, drives sidebar spinners
  DELETE /v1/stream/<conv_id>              Stop, idempotent

Router parent fans out to children for list and delete, probes on GET
to route to the owner, fans out DELETE on POST so "one session per
conv" holds across model swaps.

WebUI: the layout snapshots /v1/streams at mount and on
visibilitychange, the sidebar reflects live inferences across all
convs. The chat page reattaches on mount, append vs fresh is detected
from existing content so continue mid stream keeps its prefix.

update_slots: on llama_memory_seq_rm refusal at a deep position, full
clear of the seq and reprefill from zero instead of GGML_ABORT.

OAI strict path unchanged when the opt in headers are absent.

* server: create stream session only after post_tasks succeeds

* server, ui: drop X-Stream-Resume, X-Conversation-Id alone enables the replay buffer

* server: drop magic 17, derive the X-Conversation-Id header length from sizeof at build time

* refactor: address review feedback from ngxson

* server-context: cleaning

* server-stream: fix use-after-free on rd

Guard stop_producer with a shared alive flag, flipped by on_stream_end
before rd dies. Prevents a late cancel (session eviction by a later
POST on the same conv_id, or a DELETE arriving after the producer
ended) from touching a destroyed rd.

* ui: fix cross-conversation contamination

Scope streaming flags per conv so one finishing does not unflag the
others, guard discoverActiveStream against concurrent runs to avoid
duplicate attaches, and stop racing syncRemoteRunningStreams for the
sidebar set.

* server-http: keep request alive in detached SSE drain

The response next() lambda may reach into *request via &req long
after on_complete reset the request shared_ptr. Capture request in
the detached thread so it outlives the drain.

* ui: address review feedback from coder543

Forward Authorization to /v1/stream and /v1/streams fetches, the resumable routes
must obey --api-key like the rest of the API.

Wrap reader.read() in a try/catch, the underlying connection drop rejects with
TypeError instead of resolving done=true, treat it as a premature end of stream
so the existing resume loop kicks in.

Freeze the model at session start in chatStreamingStates.model and thread it
through cancel and resume, the dropdown selection may have changed since the
POST and the server side identity is fixed at that time.

* format

* ui: remove unused selectedModelName

* server-stream: poll session->is_cancelled() in stream_aware_should_stop

Address review feedback from coder543. The cancel propagation through
rd.stop() relies on the slot eventually processing the cancel task and
posting a result that notifies the recv condvar, remove_waiting_task_ids
does not notify directly. Add a defensive poll on session->is_cancelled()
so the producer-side next() loop exits on its next iteration after
cancel() without waiting for the cancel task to round trip through a slot.

* server-stream, ui: replace GET /v1/streams with POST /v1/streams/lookup

Address review feedback from coder543. Listing live sessions leaks the
conversation_id of every concurrent user, which defeats the random UUID
unguessability. The new route takes {conversation_ids: [...]} in the
body and returns matches only for the ids the caller already owns, so
foreign UUIDs stay private. The router fans out the same POST to every
child and aggregates, the WebUI passes the convs visible in its sidebar.

* ui: read conv ids from IndexedDB in syncRemoteRunningStreams

The conversations store is not hydrated yet at +layout onMount, so the
sidebar spinners stayed off for background convs until the user clicked
on them. Read straight from the DB to dodge the init race.

* server-models: deduplicate stream lookup timeouts behind one constant

* ui: extract visibility kick grace into a stream constant, bump to 1000 ms

* make it safer & more simple

* server-stream: survive client disconnect via stream_pipe::finish_producer

After the RAII rewrite the generation stopped the moment the client
disconnected. httplib bails its content provider on the is_peer_alive
check at the top of write_content_chunked, so returning true from the
provider never keeps it producing: the response resets, rd is destroyed
and its task gets cancelled.

Reinstate the disconnect survival inside the pipe. stream_pipe gains
finish_producer, which pumps the response next() into the ring buffer
until the generation ends, and mark_producer_done for the clean wire
end. server-http only triggers them: mark before sink.done on a clean
close, finish in on_complete when the peer left early. No detach, no
stream logic in server-http beyond the trigger, and the strict OAI path
is untouched when no pipe is attached.

Known limitation: finish_producer pumps synchronously on the http
worker, so a disconnected stream keeps its worker busy until the
generation ends. A follow-up will move the drain off the http worker so
no worker is held.

* server-stream: drain disconnected streams on a manager owned thread

The previous commit pumped the post disconnect drain synchronously in
on_complete, on the http worker, so a disconnected stream kept its
worker busy until the generation ended. Under a wave of reloads or tab
closes that pins workers from the pool.

Move the drain off the http worker. on_complete now hands the response
to stream_session_manager::adopt_orphan, which pumps it to completion on
a manager owned thread and releases the worker at once. One thread per
disconnected stream still generating, stored in a list, joined and
reaped on the next adopt, by the GC, and at shutdown. No detach, the
thread lifecycle is fully owned by the manager. needs_drain gates the
handoff so a cleanly finished stream never spawns a thread, and the
strict OAI path stays untouched when no pipe is attached.

stop_gc now cancels sessions before finalizing them, so an in flight
drain sees is_cancelled and exits instead of blocking the shutdown join
until the generation ends naturally.

* ui: add missing JSDoc

* server-stream: drain on the http worker, drop the manager thread

Address @ngxson review: httplib runs a large dynamic pool and a worker
blocked in next() sits on a condvar instead of burning cpu, so draining
the rest of the generation on that worker is fine and much simpler than
a dedicated thread.

on_complete calls finish_producer directly again. Removes adopt_orphan,
the orphan thread list and its reaping, the stop_gc session cancel that
only existed to unblock those threads, and the now dead drain_shutdown
flag.

* server-stream: split stream_pipe into producer and consumer classes

Address @ngxson review: one class covering both ends was messy. stream_pipe
is now a base holding the session and is_cancelled, with stream_pipe_producer
(write, mark_producer_done, finish_producer, cleanup, finalizes on destruct)
and stream_pipe_consumer (read only, no finalize) deriving from it.

Drops the is_producer_ discriminator and its runtime guards, the type now
encodes the role. res.spipe is retyped to shared_ptr<stream_pipe_producer>
since it is only ever a producer. No behavior change.

* server-stream: rename producer methods to unix pipe semantics

Address @ngxson review: mark_producer_done becomes done(), finish_producer
becomes close(), matching a unix pipe write end. The producer_done_ member
follows as done_. write() is unchanged. No behavior change.

* server, ui: route resumable streams via a conv map, persist resume identity

Address ngxson review: drop the polling probe, proxy_post records a conv_id ->
model map and the stream routes resolve the owning child with one lookup. The
map is the single source of truth, the ::model suffix stays for child session
uniqueness but the router never parses it.

UI: the server keys a session by the POST time identity (conv::model), but reload
probed with the bare conv id and missed model tagged sessions, so F5 stopped the
stream and sidebar spinners stayed off. Persist the model and rebuild the exact
identity on resume, single conv and bulk sidebar both send it.

Add unit coverage for the identity round trip.

* ui: resolve continue target by id to stop cross-conversation flash on switch

* ui: skip stream resume when the abort is intentional

* server: move the conv id to model map into a self contained tracker

Address review from ngxson: server_models held two mutexes side by side, the
global one and a bare conv_model_mu guarding a loose map, which made the locking
hard to follow. Wrap the map and its lock in a small conv_model_tracker struct
that owns its mutex, one mutex per struct. The remember, lookup and forget
methods move inline into the tracker, server_models exposes a single conv_models
member and the routes call models.conv_models.lookup and friends. No behavior
change, the map stays the single source of truth for routing resumable streams
to a child.

* ui: replace stream magic values with enums and shared constants

Address review from allozaur: lift the inline literals around the resumable
stream code into named symbols so the intent is explicit and reusable.

* ui: fold the stream resume and discovery helpers into ChatService

Address review from allozaur: drop the two standalone stream-*.service files.
They were used only by the chat service and store, carried no shared state, and
did not follow the static class pattern the other services use, so a separate
abstraction was not warranted. Move the helpers onto ChatService as static
methods. No behavior change, tests now exercise them through ChatService.

* docs: document the SSE replay buffer in server README-dev

Add the resumable streaming section, list stream_session_manager in the
backend component inventory, and link PR 23226 in the related PRs.

* ui: align attachServerStream call with onCompletionId param in handleStreamResponse

* server-http: rename del_ to del to match get and post

* ui: address review feedback from allozaur

* ui: drop duplicate SSE constants, keep sse.ts canonical

* ui: use svelte:document for the visibilitychange listener

address review from allozaur: replace the manual document.addEventListener
in onMount with a declarative <svelte:document onvisibilitychange>. svelte
handles attach, detach and SSR, so the typeof document guard and the onMount
cleanup go away. onMount keeps only the first load snapshot.

* server: trim redundant stream drain comments

Address review from ngxson

* server: balance and clean up stream comments

remove redundant comments and tighten the verbose ones across the resumable
stream code, keeping the concurrency and lifetime rationale that is not obvious
from the code. also fix two stale comments in server.cpp and server-models.h
that still described the old ::model suffix probe and fan out routing, now
replaced by the conv_id -> model map

Address review from ngxson

* ui: balance and clean up stream comments

dedup repeated rationale (frozen conv::model identity, the lookup privacy note,
the abort patterns) down to one canonical spot, tighten the verbose blocks, and
keep the concurrency and resume-offset reasoning. fix stale comments in
stream-identity.ts and chat.service.ts that still described the old loopback
probe and fan out routing, now the conv_id -> model map.

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
2026-06-26 09:31:29 +02:00
Jassieluo e7e3f35090 sycl : clamp softmax input to avoid underflow (#24941) 2026-06-26 15:02:42 +08:00
Xuan-Son Nguyen b11f7c16bc mtmd: add more validations (#25013)
* mtmd: add more validations

* fix

* refactor a bit

* type check for get_arr_int
2026-06-26 08:43:29 +02:00
leonardHONG f818065d75 CUDA: batch out_prod broadcast (dps2>1) path with cublasSgemmBatched (#24426) 2026-06-26 08:51:25 +03:00
Arsen Arutunan 960d628f46 mamba2: remove hardcoded 2x expansion factor and invalid d_inner % d_state check (#23082)
* mamba2: remove hardcoded 2x expansion factor, support any expand value

* mamba2: remove invalid d_inner %% d_state check (unrelated parameters)

* Update convert_hf_to_gguf.py: make expand optional with default 2

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* mamba2: apply expand fix to refactored conversion/mamba.py

* also check for mamba_expand

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: Sigbjørn Skjæret <1629204+CISC@users.noreply.github.com>
2026-06-26 08:50:54 +03:00
shaofeiqi 5c7c22c3e1 opencl: flush profiling batch at shutdown for incomplete batches (#25016) 2026-06-25 18:48:24 -07:00
Sigbjørn Skjæret beac5309f1 xcframework : disable mtmd video on i/tv/visionos (#25018) 2026-06-26 00:13:59 +02:00
Tarek Dakhran 9d5d882d8c model : Add label for LFM2.5-230M (#25008) 2026-06-25 18:58:52 +02:00
Oliver Simons 1ec44d178d CUDA: Various fixes to cpy.cu (#25000)
* Add failing test-case to test-backend-ops

Extracted from https://github.com/ggml-org/llama.cpp/issues/24072

* Minimize repro with help of AI

N = 8 * (65535 - 1) + 1 = 524273

* Port and adjust workaround from https://github.com/LostRuins/koboldcpp/commit/0ba798341e0c70517cb226cb63c966b086a3b5b3

Fall-back should share code, also relax y-z constraint to be inclusive

* Add test-case + fallback also for y dim

* Fix x-guards which is 2^{31}-1, so inlusive of INT_MAX

* Fix overflow problems for transposed copy kernel
2026-06-25 17:29:23 +02:00
Xuan-Son Nguyen c7cddefcbd misc: fix labeler (#25012) 2026-06-25 17:23:37 +02:00
Xuan-Son Nguyen e9d1b76d0a server: use status code 403 for disabled features (#24970)
* server: use status code 403 for disabled features

* cont

* fix test case
2026-06-25 16:36:40 +02:00
Xuan-Son Nguyen 099bf06952 misc: update lables (#24920)
* misc: update lables

* bring back examples, add mtmd
2026-06-25 16:26:56 +02:00
Xuan-Son Nguyen 60bc8866b1 common: refactor model handling (#24980)
* common: refactor models handling

* remote preset

* cont

* rm skip_download option

* missing header

* fix plan.model_files

* fix --offline case

* move hf_plan to download

* refactor

* rm redundant curr_ex, add comments

* adapt
2026-06-25 15:17:51 +02:00
Kashif Rasul e8ecce53b8 docs : Eagle3 qwen3 draft model support (#24977)
* eagle3: accept Eagle3LlamaForCausalLM draft checkpoints

* docs: add eagle3 speculative decoding section

* docs: address eagle3 review comments

* docs: add more angelslim eagle3 models

* docs: add gpt-oss eagle3 models and link to pr 18039
2026-06-25 15:58:00 +03:00
Adrien Gallouët 683b04cc4a app : add the llama download subcommand (#24982)
* app : add the download command (with llama-download)

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

* Remove llama-download tool for now

Signed-off-by: Adrien Gallouët <angt@huggingface.co>

---------

Signed-off-by: Adrien Gallouët <angt@huggingface.co>
2026-06-25 13:36:36 +02:00
fairydreaming f728adab68 ggml : address integer overflows in binary ops CUDA implementation (#24706)
* ggml : address integer overflows in binary ops CUDA implementation

* ggml : add size_t casts to avoid integer overflows

* ggml : add more asserts checking integer overflows in binary ops CUDA implementation

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
2026-06-25 10:06:44 +02:00
Pascal 3e61ea0e2f ui: fix always-show-sidebar-on-desktop setting after navigation refactor (#24979) 2026-06-25 09:45:55 +02:00
Christopher Albert fdbd6abee2 tests : synchronize contexts at end of test-thread-safety (#24935)
Assisted-by: Claude
2026-06-25 09:22:51 +03:00
Abraham Gonzalez e12a0128ab build: include libmtmd in Apple XCFramework (#21935)
Adds an opt-in LLAMA_BUILD_MTMD CMake option so build-xcframework.sh
can link libmtmd.a into the framework binary without pulling in the
rest of tools/ (which doesn't cross-build cleanly to iOS/tvOS/visionOS).

- CMakeLists.txt: new option, default OFF. When on with
  LLAMA_BUILD_TOOLS=OFF, only the tools/mtmd subdir is added. Useful
  for any binding that wants just libmtmd (Apple XCFramework, WASM).
- tools/mtmd/CMakeLists.txt: gate the CLI exe targets on
  LLAMA_BUILD_TOOLS. Gating on LLAMA_BUILD_COMMON is not enough — it
  defaults ON in standalone builds and visionOS xcodebuild then fails
  with "install TARGETS given no BUNDLE DESTINATION for MACOSX_BUNDLE
  executable target 'llama-mtmd-cli'".
- build-xcframework.sh: turn the option on, pass -DLLAMA_BUILD_MTMD,
  add libmtmd.a to combine_static_libraries, and copy mtmd.h and
  mtmd-helper.h into the framework Headers dir. The umbrella module
  map then exposes them, so Swift / Obj-C consumers can import the
  mtmd C API directly.

After this, nm on ios-arm64/llama.framework/llama shows 52 _mtmd_
symbols. Verified end-to-end: a Swift target links the produced
framework and calls mtmd_default_marker, mtmd_bitmap_init, etc.
without a shim on macos / iphoneos / iphonesimulator / xros slices.

Co-authored-by: Abraham Gonzalez <abraham@theabecaster.com>
2026-06-25 08:37:30 +03:00
Sigbjørn Skjæret b3ce5cedf4 quant : fix quantizing moe with mtp (#24986) 2026-06-25 08:36:49 +03:00
David Spruill e9fb3b3fc0 sycl : support --split-mode tensor (#24152)
* Sycl tp stage1 (#1)

* SYCL: tensor parallelism (--split-mode tensor) for dual-GPU

Adds the comm_init/comm_free/comm_allreduce_tensor trio that the
meta-backend queries via get_proc_address to enable backend-specific
all-reduce, mirroring the pattern used by ggml-cuda.cu.

For N=2 (the common dual-GPU case) implements a degenerate ring
all-reduce with two size-branched paths:

  * Small (nelem < 32768): FP32 direct memcpy + per-device ADD kernel
    chained via depends_on(memcpy_event). 4 SYCL submissions/call.

  * Large (nelem >= 32768): BF16-compressed. Each device compresses
    FP32 -> BF16 in a local outbox, cross-device memcpys to the peer's
    inbox (HALF the PCIe bytes), then decompresses + adds into the
    local FP32 partial. 6 SYCL submissions/call but PCIe bytes halved
    -- wins for any tensor where PCIe dominates kernel time.

Threshold and BF16 path pattern mirror the CUDA NCCL allreduce.

Storage: ONE persistent uint8_t buffer per device, 4 * nelem bytes
(matches both path layouts: FP32 nelem floats; BF16 outbox+inbox =
2 * nelem uint16_t each). Single alloc+free per device keeps the
SYCL pool's strict-LIFO invariant trivial.

Initial impl handles N=2 FP32 contiguous tensors. Other cases return
false, causing the meta-backend to use its generic butterfly fallback.

Per-call sync is intentionally omitted. SYCL in-order queue semantics
ensure that the meta-backend's next compute on the same per-device
queue waits for our final ADD, and the next allreduce's first op on
the same persistent buffer waits via the same queue. Only comm_free
does an explicit final wait.

OneCCL is NOT used: OneCCL 2021.17 hardcodes single-device-per-process
in communicator_impl.hpp:47 (condition devices.size() == 1), which is
incompatible with llama.cpp's single-process multi-GPU model.

Measured on dual Intel Arc Pro B70 (NEO 26.05.x, oneAPI 2025.3 +
DPC++ nightly):

  Llama-3.3-70B Q4_K_M, -sm tensor -fa 1 -ctk f16 -ctv f16:
    pp512 = 377.08 t/s  (vs 313.65 layer mode = +20.2%)
    tg128 = 17.40 t/s   (vs   9.74 layer mode = +78.6%)

  Qwen3-Coder-Next-80B-A3B Q3_K_M (MoE):
    pp512 = 216.56 t/s  (vs 156.58 meta-backend butterfly = +38.3%)
    tg128 = 17.60 t/s   (vs  14.31 meta-backend butterfly = +23.0%)

  Qwen3-4B Q4_K_M:
    pp64  = 984.51 t/s, tg16 = 49.29 t/s

Llama-3.3-70B in SYCL TP now comfortably beats production layer mode
on both prefill and decode. Coder-Next-80B-A3B (MoE) also wins on
both — the BF16 path is what unlocks the many-medium-allreduces
prefill pattern.

Build/CMake: no changes. No new dependencies. ~210 lines added across
ggml-sycl.h and ggml-sycl.cpp.

* Fix comments

* documentation update to address PR feedback

* Bring over my device-to-device memcpy chagnes

* move the dev2dev_memcpy calls to the upstream 7-parameter variety

* Fix a typo and remove a trailing whitespace
2026-06-25 08:35:21 +03:00
Neo Zhang 9c10954865 sycl : fix the failed UT cases of conv_3d (#24900) 2026-06-25 08:27:58 +03:00
lhez fdb2c11c70 opencl: support non-contig rows in norm (#24965) 2026-06-24 19:21:25 -07:00
Piotr Wilkin (ilintar) 09cedfd699 chat: harden caps check (#24973) 2026-06-25 02:49:22 +02:00
Max Krasnyansky 8be759e6f7 hexagon: MUL_MAT and MUL_MAT_ID rework : 32x32 tiled weight repack, kernel-params, cached graphs (#24954)
* hex-mm: new weight layout and fusion updates

* hvx-mm: unroll the new tiled vec_dots to optimize hvx register util

* hex-mm: optimize dyn.quant format for q8_0 and q8_1 to reduce overhead in vec_dots.

* hvx-mm: parallel quantizer per block for large rows

* hvx-mm: simplify and futher optimize dyn.quant and vec_dots

* hvx-mm: keep intermediate per tile accumulators in fp16

* hmx-mm: optimize weight dequant by aligning the repacked tiles with the DMA

* hmx-mm: remove qweight scratch and just use vtcm_weight

* hmx-mm: remove all unused and obsolete code

* hmx-mm: the new tiled repack format is here to stay -- rename all x4x2 to _tiled

* hmx-mm: improve activation processing with dma prefetch

* hex-mm: fix hmx/hvx fallback logic and MUL_MAT_ID allocation (unbreaks OLMoE)

* hex-mm: align the weight tiles with dma just like we did in hmx-mm

* hex-mm: factor out common mm bits into htp/matmul-ops.h

* hex-mm: start moving mm kernel selection to the host

* hex-mm: move all of the matmul param compute into the host

* hmx-mm: restore pipelined mode

* hmx-mm: unroll the dequant functions to optimize register usage

* hmx-mm: further improve activation process

* hex-mm: use vtcm_seq_alloc for all vtcm allocations and define more common functions

* hex-mm: improve mm optimizer to acount for number of activation threads

* hex-mm: fix matmul-id kernel params selection (unbreaks OLMoE and LFM)

* hexagon: remove support for arch < v73 since HMX is now required for most use-cases

* hex-mm: cleanup naming for consistency

* hex-mm: make sure matmul fusion accounts for vtcm allocation

* hex-mm: minor cleanup for kernel_params definition

* hex-mm: replace hardcoded limits with proper checks for vtcm requirements

* hex-mm: add support for non-tiled mm as a fallback option and factor out hvx kernels into separate header

* hex-mm: remove unused functions

* hex-mm: add shorthand for MM_SELECT in run-tool script

* hvx-mm: factor out hvx/hmx microkernels and unify matmul entry and dispatch

* hex-mm: further cleanup matmul fallback path

* hex-mm: refactor matmul entry point and dispatch a bit further

* hexagon: update cmake build to enable hmx for everything

* hex-ops: optimize kernel_param updates and include summary in the logs

* hex-mm: add support for GGML_HEXAGON_MM_SELECT

* hex-mm: add hex-common header

* hex-mm: pass correct number of tasks to workpool

* hex-mm: add proper checks for no-work in dyn.quant tasks

* hex-mm: convert all quantizers into a macro

* hex-mm: fix hvx-flat fallback to pass all MUL_MAT tests

* hex-mm: vectorize q8_1 quantizer

* hex-mm: improve fused ffn mm stride handling

* hex-mm: consistent use of n_threads and pipeline in kernel_params

* hexagon: minor formatting

* hex-mm: update MUL_MAT_ID kernel_param handling to make sure host/npu are in sync

* hvx-mm: go back to accumulating in fp32 in tiled hvx kernels, more accurate and same perf

* hvx-mm: unroll the loops and remove masking that is not needed for tiled accums

* hmx-mm: optimize activation processing (slit loops, some unrolling, etc)

* hmx-mm: minor optimization for output processing

* hex-mm: consistent use of uint32_t and size_t in mm kernels

* hex-mm: remove legacy restrictions for rows to be multiple of 256

* hexagon: replace sprintf with snprintf

* hex-mm: relax hardcoded nrows checks and rely on VTCM size requirements

* hexagon: minor alignment fix

* hexagon: fix trailing spaces

* hex-mm: relax padding from 256 to 128 (leftovers)

* hex-mm: remove redundant checks for weight align to 128

we always use 2D dma for the weights and align them properly

* hmx-mm: MUL_MAT_ID better work distribution between hvx threads and hmx tracing

* hex-mm: specialize per-token mmid activation handling

* hex-profile: update python scripts to handle kernel-params section in the logging output

* hex-mm: move n_prefetch (aka dma_depth) into kernel params and remove unused fields

* hex-trace: use easier to parse format, simply and fix post-proc scripts

* hmx-mm: relax 32 row limit for output processing which helps utilization

* hmx-mm: use start-chunk idx for tracing info

* hmx-mm: parameterize activation dma pipeline

* hexagon: add support for simple graph caching to avoid recomputing kernel-params

* hex-mm: remove left-over repack functions

* hex-mm: tighten n_prefetch asserts

* hex-mm: remove duplicate round/align_up helper

* hexagon: cleanup common header used in host/npu

* hexagon: update early wakeup threshold

* hmx-mm: define cost constants and update solver to assume that repacked ne[1] is padded to 32

* hmx-mm: make precompute_matmul a bit more readable (split into smaller functions, etc)

* hex-mm: remove n_threads constraint

* hex-mm: minor formatting updates

* hex-mm: remove obsolete profiling logs

* hex-mm: restore hardcode gate to refuse lm-head to avoid repacking that tensor
2026-06-24 12:14:25 -07:00
Saba Fallah 894bb27af3 mtmd: model: unlimited-ocr: converter + parity test (#24969) 2026-06-24 18:20:22 +02:00
Xuan-Son Nguyen fb401045cc common: remove unused json-partial (#24968) 2026-06-24 18:12:16 +02:00
Wagner Bruna 51eae8cfca vulkan: allow reducing the graph submission batches to avoid timeouts (#24872) 2026-06-24 16:29:24 +02:00
liminfei-amd 1191758c5d vulkan: fail the build when a shader fails to compile (#24450)
* vulkan-shaders-gen: fail the build when a shader fails to compile

vulkan-shaders-gen did not detect shader-compile subprocess failures, so a
broken libggml-vulkan could be produced while the build reported success and
the breakage only surfaced at run time. execute_command() discarded the child
exit code (POSIX waitpid passed nullptr for status; the Windows branch never
called GetExitCodeProcess) and string_to_spv decided success only from whether
stderr was empty, so a non-zero exit with empty stderr, or a subprocess that
failed to launch, was treated as success.

Return the child exit code from execute_command() (WEXITSTATUS on POSIX,
GetExitCodeProcess on Windows), treat a non-zero exit or non-empty stderr or a
launch exception as a failure, and record it in an atomic flag. main() checks
the flag after process_shaders() and returns EXIT_FAILURE before writing the
output files, so the build stops instead of emitting a broken backend.

Fixes #24393

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>

* vulkan-shaders-gen: simplify compile_failed access and drop unreachable return

Address review feedback on #24450:
- Access the std::atomic<bool> compile_failed directly (= / implicit bool)
  instead of .store()/.load(); the flag stays atomic because the worker
  threads in process_shaders() set it concurrently.
- Remove the unreachable trailing return -1 in execute_command(): on POSIX the
  child _exit()s after execvp and the parent returns (fork()<0 throws); on
  Windows the block returns the exit code.

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>

---------

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
2026-06-24 11:42:03 +02:00
Pascal 00139b660b ui: loading bar below the model picker (#24931)
* ui: show model load progress on the selector trigger

Mirror the in-dropdown stage progress as a thin bar on the selector
trigger, so the active model's load percent stays visible when the menu
is closed. Same status gating and composite fraction as the dropdown
row, so both bars track the selected model in sync.

Suggested-by: Julien Chaumond <@julien-c>

* ui: show model load progress bar on the in-conversation model selector

* ui: tune model load indicator to a pulsing highlight (suggested by @ngxson)

Also wire the indicator onto the mobile sheet trigger, which was missing
it since mobile uses the sheet instead of the dropdown.

* ui: thin (@allozaur) pulsating (@ngxson) model load bar
2026-06-24 10:50:44 +02:00
Aleksander Grygier ef9c13d4c2 ui: New Logo + Navigation cleanup & Mobile UI/UX improvements (#24897)
* chore: `npm audit fix --force`

* feat: Update sidebar toggle to use Logo

* refactor: Clean up favicon SVG

* feat: Refactor logo component and implement theme-aware favicon generation

* feat: Add configurable padding to generated PWA assets

* test: Add unit tests for writeThemeFavicons

* refactor: Componentization

* feat: WIP

* feat: WIP

* feat: WIP

* feat: Mobile UI

* feat: add SEARCH route constant

* feat: create SidebarNavigationSearchResults component

* refactor: use SidebarNavigationSearchResults in conversation list

* feat: enable mobile search navigation in sidebar actions

* feat: add mobile search route and page

* fix: prevent sidebar overflow on mobile viewports

* fix: Mobile sidebar

* feat: Mobile Search WIP

* feat: Mobile WIP

* feat: Add PWA standalone detection and refine mobile UI

* feat: Improve mobile layout, sidebar handling, and chat scrolling

* feat: Improve mobile sidebar visibility and iOS Safari chat spacing

* fix: Disable auto-scroll on mobile

* chore: Linting

* fix: Wrong condition

* feat: Mobile chat scroll

* refactor: WIP

* fix: Desktop initial scroll always working again

* fix: Partial fix for mobile auto-scroll / initial scroll

* fix: Desktop auto-scroll on initial load and during streaming

* fix: Mobile scrolling logic

* refactor: Clean up

* feat: Improve start UI

* feat: Add `delay` to `fadeInView`

* feat: Auto-scroll button

* refactor: Cleanup

* refactor: Extract chat dialogs and alerts into dedicated component

* refactor: Reorganize ChatScreen component structure and initialization

* feat: Improve auto-scroll after sending message

* feat: UI improvements

* fix: Settings link

* feat: UI improvements

* fix: better UI spacing

* fix: Remove unneeded logic

* fix: Chat Processing Info UI rendering

* feat: Improve mobile UI

* feat: UI improvement

* fix: Conditional transition delay for Chat Messages based on route from

* fix: Delay mobile sidebar collapse for smoother transitions

* fix: Mobile scroll down button + sidebar pointer events

* fix: Mobile UI

* fix: Auto scrolling

* fix: Implement dynamic height calculations for chat auto-scroll positioning and UI elements

* fix: Retrieve `autofocus` for Chat Form textarea

* fix: Use proper class

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* refactor: extract scroll-to-bottom logic and fix message send flow

* fix: update viewport store usage and remove conflicting autofocus

* feat: add accessibility labels to scroll down button

* fix: correct HTML structure in sidebar empty states

* fix: dynamically toggle processing info visibility

* chore: remove commented exports and fix formatting

* fix

* fix: Mobile Chat Form Add Action Sheet interactions

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-06-24 10:21:33 +02:00
Tarek Dakhran 88636e178f model : Add LFM2.5-ColBERT-350M and LFM2.5-Embedding-350M (#24913)
* model : Add LFM2.5-ColBERT-350M and LFM2.5-Embedding-350M

* Restore LFM2 models in README.md
2026-06-24 09:49:46 +03:00
Jeff Bolz ac4105d68b vulkan: Apply bias before softmax in FA, to avoid overflow (#24909) 2026-06-23 22:34:00 -05:00
kononnable be4a6a63eb server : check draft context creation error (#24922) 2026-06-23 16:56:50 +02:00
Jeff Bolz 72a9269172 vulkan: support all backend tests for SQR/SQRT/SIN/COS/CLAMP/LEAKY_RELU/NORM (#24582)
* vulkan: make SQR/SQRT/SIN/COS/CLAMP/LEAKY_RELU use unary.comp

* vulkan: make NORM support noncontig

* add noncontiguous row test cases for norm/l2_norm, handle this in the CPU backend and l2_norm.comp

* fix supports_op for cuda and webgpu
2026-06-23 09:48:24 -05:00
Jeff Bolz 92e854ab83 vulkan: Support GET_ROWS_BACK (#24883) 2026-06-23 15:39:37 +02:00
Jeff Bolz c5606364b2 vulkan: support CONV_3D (#24612)
* vulkan: support CONV_3D

This is a pretty direct port of conv2d_mm.comp to CONV_3D, done by codex
and cleaned up by me.

* disable slower perf tests
2026-06-23 15:39:20 +02:00
Jeff Bolz 0eb874d374 vulkan: make mul_mm ALIGNED a spec constant (#24689)
This trims down some of the shader variant explosion and reduces binary size.
2026-06-23 14:26:17 +02:00
Xuan-Son Nguyen 75ad0b23ed server: fix remote preset handling, add test (#24938)
* server: add test for remote preset

* fix remote preset handling

* fix

* fix test
2026-06-23 13:28:34 +02:00
Wyatt Caldwell c926ad0985 vulkan: link ggml-cpu when GGML_VULKAN_CHECK_RESULTS / RUN_TESTS are enabled (#24444)
The result-checking and test debug paths in ggml-vulkan.cpp call ggml_graph_compute_with_ctx() to compute a CPU reference graph, but that symbol is defined in ggml-cpu, which ggml-vulkan does not link. Enabling -DGGML_VULKAN_CHECK_RESULTS=ON (or -DGGML_VULKAN_RUN_TESTS=ON) therefore fails to link with an unresolved external (e.g. LNK2019 on MSVC, undefined reference on GCC/Clang). This regressed after ggml-cpu was split into its own library. Link ggml-cpu under those two options so the debug builds link again.

Signed-off-by: Wyatt Caldwell <218154709+Detensable@users.noreply.github.com>
2026-06-23 12:55:46 +02:00
Gabe Goodhart a3900a6694 model: Granite Speech Plus (#24818)
* feat: Add conversion support for Granite Speech Plus

Branch: GraniteSpeechPlus
AI-usage: full (Bob, OpenCode + Qwen3.6-35b)
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* feat: Extend granite_speech to support plus multi-layer concatenation

Branch: GraniteSpeechPlus
AI-usage: draft (Bob, OpenCode + Qwen3.6-35b)
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(conversion): Fix plural naming for feature_layers for audio

Branch: GraniteSpeechPlus
AI-usage: none
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* fix(mtmd): Align feature_layer usage and naming everywhere

Branch: GraniteSpeechPlus
AI-usage: none
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

* style: Use fstring for log

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>

---------

Co-authored-by: Xuan-Son Nguyen <thichthat@gmail.com>
2026-06-23 12:03:31 +02:00
Masashi Yoshimura 7c908502ea ggml-webgpu: improve MTP inference by using mat-vec path for small batches (#24811)
* ggml-webgpu: improve small batches decoding

* Add barrier to the NUM_COLS loop in mul-mat-vec
2026-06-23 17:13:55 +09:00
Masashi Yoshimura 035cd8f9a6 codeowners: add yomaytk to ggml-webgpu (#24930) 2026-06-23 15:19:34 +09:00
Aldehir Rojas 73618f27a8 server: improve user message detection and create checkpoints at every user message (#24176)
* server : improve message span logic

* cont : cast size_t to int32_t in comparisons

* server : create checkpoints before every user msg

* chat : remove \n in gemma4 delimiters

* chat : merge msg delimiter structs into one

* cont : reword comment

* cont : initialize tokens in delimiter

* cont : add server_tokens::get_raw_tokens() for mtmd

* cont : move message finding to server_tokens and skip mtmd tokens

* cont : update cohere2moe parser

* cont : increase min-step to 8192 and always produce a chkpt for last user message
2026-06-23 08:27:28 +03:00
Shawn Gu 23ee8797e1 opencl: q8_0 gemv precision improvement (#24923) 2026-06-22 22:25:21 -07:00
Matt Thompson dec5ca5577 server : Add id to tool call responses api (#24882) 2026-06-22 23:03:12 +02:00
Mahdiou Diallo 9c0ac887f3 ui: Prioritize favorite models in model selection (#24766)
Updated model selection prioritization to include favorite models.
2026-06-22 21:00:21 +02:00
Xuan-Son Nguyen 721354fbdf server: (router) move model downloading to dedicated process (#24834)
* server: real-time model load progress tracking via /models/sse

* update docs

* server: move model download to child process

* rm unused

* fix most problems

* clean up

* nit fixes

* fix test case

* do not detact() thread

* shorter MODEL_DOWNLOAD_TIMEOUT in test

* throttle
2026-06-22 18:24:04 +02:00
Xuan-Son Nguyen 6ee0f65793 server: refactor/generalize input file schema (#24299)
* server: refactor/generalize input file schema

* wire up input_video, accept raw base64

* nits

* nits (2)

* fix windows
2026-06-22 16:42:47 +02:00
Pascal 099b579acb ui: model status and load progress via /models/sse feed (#24878)
* ui: model status and load progress via /models/sse feed

* ui: centralize SSE wire-format delimiters into shared constants for the chat and /models/sse parsers

* ui: type /models/sse event names as a ServerModelsSseEventType enum

Address review from allozaur
2026-06-22 15:55:30 +02:00
Neo Zhang f8cc15f163 [SYCL] support bf16 on bin_bcast OP and unary OPs (#24838)
* support bf16 on bin_bcast OP and unary OPs

* support the older Intel compiler than 2026.0
2026-06-22 14:09:02 +03:00
Tim Neumann 37957e8531 sampling : remove unconditional softmax+sort in top-n-sigma sampler (#22645) 2026-06-22 14:08:32 +03:00
Pascal d0f9d2e5ac server: fix edit_file crash on append at end of file (line_start -1) (#24893)
line_start -1 normalized to n+1, so append inserted at lines.begin() + n + 1,
one past end() -> heap-buffer-overflow in vector::_M_range_insert.

Normalize -1 to n (insert at end()), restrict -1 to append mode and reject it
for replace/delete instead of silently clobbering the last line. Parenthesize
the insert offset so empty-file append computes the position as int first,
avoiding a transient begin() - 1 on a null vector data pointer.
2026-06-22 10:55:28 +02:00
aafsmarak 0ef6f06d55 docs/android.md: Add dependency libandroid-spawn for building in termux (#21812)
Fixes https://github.com/ggml-org/llama.cpp/issues/18615
2026-06-22 05:48:31 +02:00
Aldehir Rojas 52b3df0023 common/peg : implement ac parser for stricter grammar generation (#24869)
* common/peg : implement ac parser

* cont : extract functions

* cont : tidy up

* cont : remove a test

* cont : move ac() def
2026-06-21 16:20:58 -05:00
Xuan-Son Nguyen 7c082bc417 server: fix report progress for loading spec models, add "stages" list (#24870)
* server: fix report progress for loading spec models, add "stages" list

* improve

* nits

* nits 2
2026-06-21 17:36:52 +02:00
317 changed files with 24248 additions and 13612 deletions
+2 -2
View File
@@ -145,7 +145,7 @@ ENTRYPOINT ["/app/tools.sh"]
# ==============================================================================
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
ENTRYPOINT [ "/app/llama-cli" ]
@@ -156,7 +156,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
HEALTHCHECK --interval=5m CMD [ "curl", "-f", "http://localhost:8080/health" ]
+2 -2
View File
@@ -104,7 +104,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -115,7 +115,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+2 -2
View File
@@ -113,7 +113,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -124,7 +124,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+2 -2
View File
@@ -141,7 +141,7 @@ ENTRYPOINT ["/app/tools.sh"]
FROM base AS light
COPY --from=build /app/lib/ /app
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -153,7 +153,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/lib/ /app
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+2 -2
View File
@@ -115,7 +115,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -126,7 +126,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+8 -8
View File
@@ -1,12 +1,12 @@
ARG OPENVINO_VERSION_MAJOR=2026.2
ARG OPENVINO_VERSION_FULL=2026.2.0.21903.52ddc073857
ARG OPENVINO_VERSION_MAJOR=2026.2.1
ARG OPENVINO_VERSION_FULL=2026.2.1.21919.ede283a88e3
ARG UBUNTU_VERSION=24.04
# Intel GPU driver versions. https://github.com/intel/compute-runtime/releases
ARG IGC_VERSION=v2.34.4
ARG IGC_VERSION_FULL=2_2.34.4+21428
ARG COMPUTE_RUNTIME_VERSION=26.18.38308.1
ARG COMPUTE_RUNTIME_VERSION_FULL=26.18.38308.1-0
ARG IGC_VERSION=v2.36.3
ARG IGC_VERSION_FULL=2_2.36.3+21719
ARG COMPUTE_RUNTIME_VERSION=26.22.38646.4
ARG COMPUTE_RUNTIME_VERSION_FULL=26.22.38646.4-0
ARG IGDGMM_VERSION=22.10.0
# Intel NPU driver versions. https://github.com/intel/linux-npu-driver/releases
@@ -214,7 +214,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app/
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app/
WORKDIR /app
@@ -225,7 +225,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app/
COPY --from=build /app/full/llama /app/full/llama-server /app/
WORKDIR /app
+2 -2
View File
@@ -127,7 +127,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -138,7 +138,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+2 -2
View File
@@ -124,7 +124,7 @@ WORKDIR /llama.cpp/bin
# Copy llama.cpp binaries and libraries
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama-cli /llama.cpp/bin/llama-completion /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama /llama.cpp/bin/llama-cli /llama.cpp/bin/llama-completion /llama.cpp/bin
ENTRYPOINT [ "/llama.cpp/bin/llama-cli" ]
@@ -138,7 +138,7 @@ WORKDIR /llama.cpp/bin
# Copy llama.cpp binaries and libraries
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama-server /llama.cpp/bin
COPY --from=collector /llama.cpp/bin/llama /llama.cpp/bin/llama-server /llama.cpp/bin
EXPOSE 8080
+2 -2
View File
@@ -107,7 +107,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -118,7 +118,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+2 -2
View File
@@ -97,7 +97,7 @@ ENTRYPOINT ["/app/tools.sh"]
### Light, CLI only
FROM base AS light
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
WORKDIR /app
@@ -108,7 +108,7 @@ FROM base AS server
ENV LLAMA_ARG_HOST=0.0.0.0
COPY --from=build /app/full/llama-server /app
COPY --from=build /app/full/llama /app/full/llama-server /app
WORKDIR /app
+26 -19
View File
@@ -35,8 +35,20 @@ AMD ZenDNN:
documentation:
- changed-files:
- any-glob-to-any-file:
- "**/*.md"
- docs/**
- media/**
examples:
- all:
- changed-files:
- any-glob-to-any-file:
- app/**
- examples/**
- tools/**
- all-globs-to-all-files:
- '!tools/server/**'
- '!tools/mtmd/**'
- '!tools/ui/**'
testing:
- changed-files:
- any-glob-to-any-file:
@@ -47,28 +59,12 @@ build:
- cmake/**
- CMakeLists.txt
- CMakePresets.json
examples:
- changed-files:
- any-glob-to-any-file:
- examples/**
- tools/**
devops:
- changed-files:
- any-glob-to-any-file:
- .devops/**
- .github/**
- ci/**
python:
- changed-files:
- any-glob-to-any-file:
- "**/*.py"
- requirements/**
- gguf-py/**
- .flake8
script:
- changed-files:
- any-glob-to-any-file:
- scripts/**
android:
- changed-files:
- any-glob-to-any-file:
@@ -81,9 +77,20 @@ server:
- changed-files:
- any-glob-to-any-file:
- tools/server/**
mtmd:
- changed-files:
- any-glob-to-any-file:
- tools/mtmd/**
conversion:
- changed-files:
- any-glob-to-any-file:
- conversion/**
- convert_*.py
- gguf-py/**
vendor:
- changed-files:
- any-glob-to-any-file:
- vendor/**
ggml:
- changed-files:
- any-glob-to-any-file:
+4 -4
View File
@@ -68,8 +68,8 @@ jobs:
env:
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Clone
@@ -96,8 +96,8 @@ jobs:
env:
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Clone
+4 -4
View File
@@ -39,8 +39,8 @@ jobs:
env:
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Clone
@@ -96,8 +96,8 @@ jobs:
env:
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Clone
+2 -2
View File
@@ -266,8 +266,8 @@ jobs:
env:
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Clone
+58 -11
View File
@@ -446,8 +446,8 @@ jobs:
env:
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Set OpenVINO version output
@@ -506,8 +506,11 @@ jobs:
cmake -B build/ReleaseOV -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_OPENVINO=ON \
-DHF_UI_VERSION=${{ needs.get-version.outputs.ui_version }}
cmake --build build/ReleaseOV --config Release -j $(nproc)
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
-DHF_UI_VERSION=${{ needs.get-version.outputs.ui_version }} \
${{ env.CMAKE_ARGS }}
cmake --build build/ReleaseOV --config Release --parallel
- name: ccache-clear
uses: ./.github/actions/ccache-clear
@@ -521,8 +524,26 @@ jobs:
- name: Pack artifacts
id: pack_artifacts
run: |
cp LICENSE ./build/ReleaseOV/bin/
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz --transform "s,^\.,llama-${{ steps.tag.outputs.name }}," -C ./build/ReleaseOV/bin .
dest=./build/ReleaseOV/bin
OPENVINO_ROOT=./openvino_toolkit
ov_lib="$OPENVINO_ROOT/runtime/lib/intel64"
# Bundle OpenVINO runtime libs + TBB. Binaries built with RPATH=$ORIGIN
# load these siblings without setupvars.sh / LD_LIBRARY_PATH.
cp -P "$ov_lib"/libopenvino.so* \
"$ov_lib"/libopenvino_c.so* \
"$ov_lib"/libopenvino_*_plugin.so \
"$ov_lib"/libopenvino_intel_npu_compiler*.so \
"$OPENVINO_ROOT"/runtime/3rdparty/tbb/lib/*.so* \
"$dest"
cp -P /usr/lib/x86_64-linux-gnu/libOpenCL.so.1* "$dest" 2>/dev/null || true
cp "$ov_lib"/cache.json "$dest" 2>/dev/null || true
# OpenVINO licensing
cp -r "$OPENVINO_ROOT"/docs/licensing "$dest"/openvino-licensing
cp LICENSE "$dest"
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz --transform "s,^\.,llama-${{ steps.tag.outputs.name }}," -C "$dest" .
- name: Upload artifacts
uses: actions/upload-artifact@v6
@@ -531,6 +552,9 @@ jobs:
name: llama-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz
windows-openvino:
needs: [check-release]
if: ${{ needs.check-release.outputs.should_release == 'true' }}
runs-on: windows-2022
outputs:
@@ -538,8 +562,8 @@ jobs:
env:
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
OPENVINO_VERSION_MAJOR: "2026.2"
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR: "2026.2.1"
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
steps:
- name: Set OpenVINO version output
@@ -607,7 +631,9 @@ jobs:
-A x64 ^
-DCMAKE_BUILD_TYPE=Release ^
-DGGML_OPENVINO=ON ^
-DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake
-DLLAMA_BUILD_BORINGSSL=ON ^
-DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake ^
${{ env.CMAKE_ARGS }}
cmake --build build\ReleaseOV --config Release -- /m
@@ -624,8 +650,29 @@ jobs:
id: pack_artifacts
shell: powershell
run: |
Copy-Item LICENSE .\build\ReleaseOV\bin\
7z a -snl llama-${{ steps.tag.outputs.name }}-bin-win-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.zip .\build\ReleaseOV\bin\*
# Locate the extracted OpenVINO toolkit root (same pattern as the Build step).
$OPENVINO_ROOT = (Get-ChildItem -Directory openvino_toolkit | Select-Object -First 1).FullName
if (-not $OPENVINO_ROOT) {
Write-Error "OpenVINO toolkit folder not found under .\openvino_toolkit"
exit 1
}
$dest = ".\build\ReleaseOV\bin\Release"
$ovBin = Join-Path $OPENVINO_ROOT 'runtime\bin\intel64\Release'
Copy-Item -Path (Join-Path $ovBin '*.dll') -Destination $dest -Force
Copy-Item -Path (Join-Path $ovBin 'cache.json') -Destination $dest -Force
$tbbBin = Join-Path $OPENVINO_ROOT 'runtime\3rdparty\tbb\bin'
Copy-Item -Path (Join-Path $tbbBin 'tbb*.dll') -Destination $dest -Force
# OpenVINO licensing
$licensingDest = Join-Path $dest 'openvino-licensing'
New-Item -ItemType Directory -Force -Path $licensingDest | Out-Null
Copy-Item -Path (Join-Path $OPENVINO_ROOT 'docs\licensing\*') -Destination $licensingDest -Recurse -Force
Copy-Item LICENSE $dest
7z a -snl llama-${{ steps.tag.outputs.name }}-bin-win-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.zip $dest\*
- name: Upload artifacts
uses: actions/upload-artifact@v6
+10
View File
@@ -222,6 +222,16 @@ if (LLAMA_BUILD_APP)
add_subdirectory(app)
endif()
# Standalone libmtmd build without pulling in the rest of the tools/ tree.
# Useful when packaging just the mtmd library for language bindings (e.g. an
# Apple XCFramework, or a WASM build). When the full tools build is enabled,
# mtmd is already built by the tools/ subdirectory above; this hook only fires
# when LLAMA_BUILD_TOOLS is OFF to avoid double-adding the target.
option(LLAMA_BUILD_MTMD "llama: build tools/mtmd library standalone" OFF)
if (LLAMA_BUILD_MTMD AND NOT (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TOOLS))
add_subdirectory(tools/mtmd)
endif()
#
# install
#
+1 -1
View File
@@ -10,7 +10,7 @@
# ggml-org/ggml-rpc : rgerganov
# ggml-org/ggml-sycl : arthw
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
# ggml-org/ggml-webgpu : reeselevine
# ggml-org/ggml-webgpu : reeselevine, yomaytk
# ggml-org/ggml-zdnn : taronaeo
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
# ggml-org/llama-mtmd : ngxson
+3 -1
View File
@@ -142,7 +142,9 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2)
- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25)
- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos)
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
+1 -1
View File
@@ -80,7 +80,7 @@ To protect sensitive data from potential leaks or unauthorized access, it is cru
### Untrusted environments or networks
If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions:
* Do not use the RPC backend, [rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
* Do not use the RPC backend, [ggml-rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value.
* Encrypt your data if sending it over the network.
+1 -1
View File
@@ -1,6 +1,6 @@
set(TARGET llama-app)
add_executable(${TARGET} llama.cpp)
add_executable(${TARGET} llama.cpp download.cpp)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama)
target_link_libraries(${TARGET} PRIVATE
+71
View File
@@ -0,0 +1,71 @@
#include "arg.h"
#include "common.h"
#include "download.h"
#include "log.h"
#include <cstdio>
#include <filesystem>
static void print_usage(int /*argc*/, char ** argv) {
printf(
"\nexamples:\n"
" %s -hf ggml-org/gemma-3-4b-it-qat-GGUF\n"
" %s -hf ggml-org/gemma-3-4b-it-qat-GGUF:Q4_K_M\n"
" %s -hf ggml-org/models -hff model.gguf\n"
" %s -mu https://example.com/model.gguf -m model.gguf\n"
"\n",
argv[0], argv[0], argv[0], argv[0]
);
}
int llama_download(int argc, char ** argv);
int llama_download(int argc, char ** argv) {
common_init();
common_params params;
params.verbosity = LOG_LEVEL_ERROR;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DOWNLOAD, print_usage)) {
return 1;
}
const bool has_source = !params.model.hf_repo.empty() || !params.model.url.empty() ||
!params.model.path.empty() || !params.model.docker_repo.empty();
if (!has_source) {
fprintf(stderr, "error: no model source specified (use --hf-repo, --model-url, --model or --docker-repo)\n");
return 1;
}
try {
common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_DOWNLOAD);
common_models_handler_apply(handler, params);
} catch (const std::exception & e) {
fprintf(stderr, "error: %s\n", e.what());
return 1;
}
if (!params.models_preset.empty()) {
// -hf pointed at a preset repo: print the preset path and stop
printf("%s\n", params.models_preset.c_str());
return 0;
}
if (params.model.path.empty()) {
fprintf(stderr, "error: model download failed\n");
return 1;
}
if (!std::filesystem::exists(params.model.path)) {
fprintf(stderr, "error: model file does not exist: %s\n", params.model.path.c_str());
return 1;
}
printf("%s\n", params.model.path.c_str());
if (!params.mmproj.path.empty()) {
printf("%s\n", params.mmproj.path.c_str());
}
if (!params.speculative.draft.mparams.path.empty()) {
printf("%s\n", params.speculative.draft.mparams.path.c_str());
}
return 0;
}
+10 -4
View File
@@ -19,6 +19,7 @@ int llama_batched_bench(int argc, char ** argv);
int llama_fit_params(int argc, char ** argv);
int llama_quantize(int argc, char ** argv);
int llama_perplexity(int argc, char ** argv);
int llama_download(int argc, char ** argv);
// Self-update is only supported for binaries built with llama-install.sh
static int llama_update(int argc, char ** argv) {
@@ -49,6 +50,7 @@ struct command {
std::vector<std::string> aliases;
bool hidden;
int (*func)(int, char **);
bool flags = false; // allow --name
};
#ifdef LLAMA_INSTALL_BUILD
@@ -61,15 +63,16 @@ static const command cmds[] = {
{"serve", "HTTP API server", {"server"}, false, llama_server },
{"cli", "Command-line interactive interface", {"client"}, false, llama_cli },
{"update", "Update llama to the latest release", {}, UPDATE_HIDDEN, llama_update },
{"download", "Download a model", {"get"}, false, llama_download },
{"completion", "Text completion", {"complete"}, true, llama_completion },
{"bench", "Benchmark prompt processing and text generation", {}, true, llama_bench },
{"batched-bench", "Benchmark batched decoding performance", {}, true, llama_batched_bench},
{"fit-params", "Compute parameters to fit a model in device memory", {}, true, llama_fit_params },
{"quantize", "Quantize a model", {}, true, llama_quantize },
{"perplexity", "Compute model perplexity and KL divergence", {}, true, llama_perplexity },
{"version", "Show version", {}, false, version },
{"licenses", "Show third-party licenses", {"credits"}, false, licenses },
{"help", "Show available commands", {}, false, help },
{"version", "Show version", {}, false, version, true },
{"licenses", "Show third-party licenses", {"credits"}, false, licenses, true },
{"help", "Show available commands", {}, false, help, true },
};
#undef UPDATE_HIDDEN
@@ -106,7 +109,10 @@ static int help(int argc, char ** argv) {
return 0;
}
static bool matches(const std::string & arg, const command & cmd) {
static bool matches(std::string arg, const command & cmd) {
if (cmd.flags && arg.size() > 2 && arg[0] == '-' && arg[1] == '-') {
arg.erase(0, 2);
}
if (arg == cmd.name) {
return true;
}
+11
View File
@@ -13,6 +13,7 @@ LLAMA_BUILD_EXAMPLES=OFF
LLAMA_BUILD_TOOLS=OFF
LLAMA_BUILD_TESTS=OFF
LLAMA_BUILD_SERVER=OFF
LLAMA_BUILD_MTMD=ON
GGML_METAL=ON
GGML_METAL_EMBED_LIBRARY=ON
GGML_BLAS_DEFAULT=ON
@@ -39,6 +40,7 @@ COMMON_CMAKE_ARGS=(
-DLLAMA_BUILD_TOOLS=${LLAMA_BUILD_TOOLS}
-DLLAMA_BUILD_TESTS=${LLAMA_BUILD_TESTS}
-DLLAMA_BUILD_SERVER=${LLAMA_BUILD_SERVER}
-DLLAMA_BUILD_MTMD=${LLAMA_BUILD_MTMD}
-DGGML_METAL_EMBED_LIBRARY=${GGML_METAL_EMBED_LIBRARY}
-DGGML_BLAS_DEFAULT=${GGML_BLAS_DEFAULT}
-DGGML_METAL=${GGML_METAL}
@@ -126,6 +128,8 @@ setup_framework_structure() {
cp ggml/include/ggml-cpu.h ${header_path}
cp ggml/include/ggml-blas.h ${header_path}
cp ggml/include/gguf.h ${header_path}
cp tools/mtmd/mtmd.h ${header_path}
cp tools/mtmd/mtmd-helper.h ${header_path}
# Create module map (common for all platforms)
cat > ${module_path}module.modulemap << EOF
@@ -247,6 +251,7 @@ combine_static_libraries() {
"${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-cpu.a"
"${base_dir}/${build_dir}/ggml/src/ggml-metal/${release_dir}/libggml-metal.a"
"${base_dir}/${build_dir}/ggml/src/ggml-blas/${release_dir}/libggml-blas.a"
"${base_dir}/${build_dir}/tools/mtmd/${release_dir}/libmtmd.a"
)
# Create temporary directory for processing
@@ -410,6 +415,7 @@ cmake -B build-ios-sim -G Xcode \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DMTMD_VIDEO=OFF \
-S .
cmake --build build-ios-sim --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
@@ -424,6 +430,7 @@ cmake -B build-ios-device -G Xcode \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DMTMD_VIDEO=OFF \
-S .
cmake --build build-ios-device --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
@@ -450,6 +457,7 @@ cmake -B build-visionos -G Xcode \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-DMTMD_VIDEO=OFF \
-S .
cmake --build build-visionos --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
@@ -465,6 +473,7 @@ cmake -B build-visionos-sim -G Xcode \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DLLAMA_BUILD_SERVER=OFF \
-DMTMD_VIDEO=OFF \
-S .
cmake --build build-visionos-sim --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
@@ -481,6 +490,7 @@ cmake -B build-tvos-sim -G Xcode \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DMTMD_VIDEO=OFF \
-S .
cmake --build build-tvos-sim --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
@@ -496,6 +506,7 @@ cmake -B build-tvos-device -G Xcode \
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
-DLLAMA_OPENSSL=OFF \
-DMTMD_VIDEO=OFF \
-S .
cmake --build build-tvos-device --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
-2
View File
@@ -80,8 +80,6 @@ add_library(${TARGET}
http.h
imatrix-loader.cpp
imatrix-loader.h
json-partial.cpp
json-partial.h
json-schema-to-grammar.cpp
llguidance.cpp
log.cpp
+256 -120
View File
@@ -297,58 +297,6 @@ struct handle_model_result {
std::string preset_path;
};
static handle_model_result common_params_handle_model(struct common_params_model & model,
const common_download_opts & opts) {
handle_model_result result;
if (!model.docker_repo.empty()) {
model.path = common_docker_resolve_model(model.docker_repo);
} else if (!model.hf_repo.empty()) {
// If -m was used with -hf, treat the model "path" as the hf_file to download
if (model.hf_file.empty() && !model.path.empty()) {
model.hf_file = model.path;
model.path = "";
}
common_download_opts hf_opts = opts;
auto download_result = common_download_model(model, hf_opts);
if (!download_result.preset_path.empty()) {
result.found_preset = true;
result.preset_path = download_result.preset_path;
return result; // skip everything else if preset.ini is used
}
if (download_result.model_path.empty()) {
throw std::runtime_error("failed to download model from Hugging Face");
}
model.path = download_result.model_path;
if (!download_result.mmproj_path.empty()) {
result.found_mmproj = true;
result.mmproj.path = download_result.mmproj_path;
}
if (!download_result.mtp_path.empty()) {
result.found_mtp = true;
result.mtp.path = download_result.mtp_path;
}
} else if (!model.url.empty()) {
if (model.path.empty()) {
auto f = string_split<std::string>(model.url, '#').front();
f = string_split<std::string>(f, '?').front();
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
}
auto download_result = common_download_model(model, opts);
if (download_result.model_path.empty()) {
throw std::runtime_error("failed to download model from " + model.url);
}
}
return result;
}
const std::vector<ggml_type> kv_cache_types = {
GGML_TYPE_F32,
GGML_TYPE_F16,
@@ -393,72 +341,241 @@ static bool parse_bool_value(const std::string & value) {
}
//
// CLI argument parsing functions
// common_models_handler
//
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
static std::string get_default_local_path(const std::string & url) {
auto f = string_split<std::string>(url, '#').front();
f = string_split<std::string>(f, '?').front();
return fs_get_cache_file(string_split<std::string>(f, '/').back());
}
common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex) {
common_download_hf_plan plan;
common_download_hf_plan plan_spec;
common_download_hf_plan plan_voc;
common_download_opts opts;
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
params.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
// only download mmproj if the current example is using it
bool use_mmproj = false;
for (const auto & ex : mmproj_examples) {
if (curr_ex == ex) {
use_mmproj = true;
break;
}
}
opts.bearer_token = params.hf_token;
opts.offline = params.offline;
opts.skip_download = params.skip_download;
opts.download_mtp = spec_type_draft_mtp;
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
opts.download_mmproj = use_mmproj && !params.no_mmproj
&& params.mmproj.path.empty() && params.mmproj.url.empty();
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
// so we should not auto-discover mtp/mmproj siblings for them
common_download_opts sub_opts = opts;
sub_opts.download_mtp = false;
sub_opts.download_mmproj = false;
if (!params.model.hf_repo.empty()) {
plan = common_download_get_hf_plan(params.model, opts);
}
try {
auto res = common_params_handle_model(params.model, opts);
if (res.found_preset) {
if (!params.models_preset.empty()) {
throw std::invalid_argument("cannot use both --models-preset and -hf with a preset.ini file");
if (!params.speculative.draft.mparams.hf_repo.empty()) {
plan_spec = common_download_get_hf_plan(params.speculative.draft.mparams, opts);
}
if (!params.vocoder.model.hf_repo.empty()) {
plan_voc = common_download_get_hf_plan(params.vocoder.model, opts);
}
return common_models_handler{plan, plan_spec, plan_voc, opts};
}
bool common_models_handler_is_preset_repo(const common_models_handler & handler) {
return !handler.plan.preset.url.empty();
}
static std::vector<common_download_task> build_url_tasks(const common_params_model & model, common_download_opts opts) {
auto parts = common_download_get_all_parts(model.url);
std::vector<common_download_task> tasks;
// single-part: download straight to model.path if the user gave one (-m), else the cache default
if (parts.size() == 1) {
common_download_task task;
task.url = parts[0];
task.local_path = model.path.empty() ? get_default_local_path(parts[0]) : model.path;
task.opts = opts;
tasks.push_back(std::move(task));
return tasks;
}
// multi-part: place each part under the user's -m directory (if given), else the cache default
std::string base_dir;
if (!model.path.empty()) {
auto pos = model.path.rfind('/');
base_dir = pos == std::string::npos ? std::string(".") : model.path.substr(0, pos);
}
for (const auto & part : parts) {
common_download_task task;
task.url = part;
task.opts = opts;
std::string local = get_default_local_path(part);
if (!base_dir.empty()) {
auto pos = local.rfind('/');
std::string name = pos == std::string::npos ? local : local.substr(pos + 1);
local = base_dir + "/" + name;
}
task.local_path = local;
tasks.push_back(std::move(task));
}
return tasks;
}
void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback) {
std::vector<common_download_task> tasks;
auto & plan = handler.plan;
auto & plan_spec = handler.plan_spec;
auto & plan_voc = handler.plan_voc;
auto opts = handler.opts; // copy
opts.callback = callback;
// handle plain "url" if needed
auto handle_url = [&](common_params_model & model) {
if (!model.url.empty()) {
if (model.path.empty()) {
model.path = get_default_local_path(model.url);
}
}
};
handle_url(params.model);
handle_url(params.mmproj);
handle_url(params.vocoder.model);
handle_url(params.speculative.draft.mparams);
// optionally, if docker repo is set, resolve it
if (!params.model.docker_repo.empty()) {
params.model.url = common_docker_resolve_model(params.model.docker_repo);
params.model.path = get_default_local_path(params.model.url);
}
// handle plain "url" tasks (non-hf)
if (!params.model.url.empty()) {
auto url_tasks = build_url_tasks(params.model, opts);
// the first part is what gets loaded, so point params.model.path at it
if (!url_tasks.empty()) {
std::string first_path = url_tasks.front().local_path;
url_tasks.front().on_done = [&]() { params.model.path = first_path; };
}
for (auto & task : url_tasks) {
tasks.push_back(std::move(task));
}
}
if (!params.mmproj.url.empty()) {
common_download_task task;
task.url = params.mmproj.url;
task.local_path = params.mmproj.path;
task.opts = opts;
tasks.push_back(task);
}
if (!params.vocoder.model.url.empty()) {
common_download_task task;
task.url = params.vocoder.model.url;
task.local_path = params.vocoder.model.path;
task.opts = opts;
tasks.push_back(task);
}
if (!params.speculative.draft.mparams.url.empty()) {
common_download_task task;
task.url = params.speculative.draft.mparams.url;
task.local_path = params.speculative.draft.mparams.path;
task.opts = opts;
tasks.push_back(task);
}
// handle hf_plan tasks
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files, common_params_model & model) {
for (size_t i = 0; i < model_files.size(); ++i) {
auto & model_file = model_files[i];
bool is_first = (i == 0);
tasks.emplace_back(model_file, opts, [&, is_first]() {
if (is_first) {
// only use first part as model path
model.path = hf_cache::finalize_file(model_file);
} else {
hf_cache::finalize_file(model_file);
}
});
}
};
if (!plan.model_files.empty()) {
add_tasks(plan.model_files, params.model);
}
if (!plan.mmproj.local_path.empty()) {
tasks.emplace_back(plan.mmproj, opts, [&]() {
params.mmproj.path = hf_cache::finalize_file(plan.mmproj);
});
}
if (!plan.mtp.local_path.empty()) {
tasks.emplace_back(plan.mtp, opts, [&]() {
// only fall back to the discovered MTP head when no draft was explicitly provided
if (params.speculative.draft.mparams.empty()) {
params.speculative.draft.mparams.path = hf_cache::finalize_file(plan.mtp);
} else {
hf_cache::finalize_file(plan.mtp);
}
});
}
if (!plan.preset.local_path.empty()) {
tasks.emplace_back(plan.preset, opts, [&]() {
// if HF repo is a preset repo, we simply run server in router mode with the preset.ini file
params.models_preset_hf = params.model.hf_repo; // only for showing a warning
params.models_preset = res.preset_path;
params.models_preset = hf_cache::finalize_file(plan.preset);
params.model = common_params_model{}; // make sure to clear model, so server starts in router mode
return true;
}
});
}
if (params.no_mmproj) {
params.mmproj = {};
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
// optionally, handle mmproj model when -hf is specified
params.mmproj = res.mmproj;
}
// only download mmproj if the current example is using it
for (const auto & ex : mmproj_examples) {
if (curr_ex == ex) {
common_params_handle_model(params.mmproj, sub_opts);
break;
// handle plan_spec (e.g. --spec-draft-hf)
if (!plan_spec.model_files.empty()) {
add_tasks(plan_spec.model_files, params.speculative.draft.mparams);
}
// handle vocoder plan (e.g. --hf-repo-v)
if (!plan_voc.model_files.empty()) {
add_tasks(plan_voc.model_files, params.vocoder.model);
}
// run all tasks in parallel
if (!params.offline) {
// if duplicated files are found, only download once (but still call on_done for each task)
std::unordered_map<std::string, common_download_task *> unique_tasks;
for (auto & task : tasks) {
auto it = unique_tasks.find(task.local_path);
if (it == unique_tasks.end()) {
unique_tasks[task.local_path] = &task;
}
}
// when --spec-type mtp is set and no draft model was provided explicitly,
// fall back to the MTP head discovered alongside the -hf model
if (spec_type_draft_mtp && res.found_mtp &&
params.speculative.draft.mparams.path.empty() &&
params.speculative.draft.mparams.hf_repo.empty() &&
params.speculative.draft.mparams.url.empty()) {
params.speculative.draft.mparams.path = res.mtp.path;
std::vector<common_download_task> unique_tasks_vec;
for (auto & pair : unique_tasks) {
unique_tasks_vec.push_back(*pair.second);
}
common_download_run_tasks(unique_tasks_vec);
}
// download successful, update params with the downloaded paths
for (const auto & task : tasks) {
if (task.on_done) {
task.on_done();
}
common_params_handle_model(params.speculative.draft.mparams, sub_opts);
common_params_handle_model(params.vocoder.model, sub_opts);
return true;
} catch (const common_skip_download_exception &) {
return false;
} catch (const std::exception &) {
throw;
}
}
//
// CLI argument parsing functions
//
static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) {
common_params & params = ctx_arg.params;
@@ -584,17 +701,22 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
}
// export_graph_ops loads only metadata
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
const bool skip_model_download =
// server will call common_params_handle_models() later, so we skip it here
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
// download calls common_params_handle_models() itself and prints the paths
ctx_arg.ex == LLAMA_EXAMPLE_DOWNLOAD ||
// export_graph_ops loads only metadata
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
if (!skip_model_download) {
// handle model and download
common_params_handle_models(params, ctx_arg.ex);
common_models_handler handler = common_models_handler_init(params, ctx_arg.ex);
common_models_handler_apply(handler, params);
// model is required (except for server)
// TODO @ngxson : maybe show a list of available models in CLI in this case
if (params.model.path.empty()
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
&& !params.usage
&& !params.completion) {
throw std::invalid_argument("error: --model is required\n");
@@ -662,15 +784,19 @@ static void common_params_print_usage(common_params_context & ctx_arg) {
common_options.push_back(&opt);
}
}
printf("----- common params -----\n\n");
print_options(common_options);
printf("\n\n----- sampling params -----\n\n");
print_options(sampling_options);
printf("\n\n----- speculative params -----\n\n");
print_options(spec_options);
// TODO: maybe convert enum llama_example to string
printf("\n\n----- example-specific params -----\n\n");
print_options(specific_options);
bool first = true;
auto print_section = [&](const char * header, std::vector<common_arg *> & options) {
if (options.empty()) {
return;
}
printf("%s----- %s -----\n\n", first ? "" : "\n\n", header);
first = false;
print_options(options);
};
print_section("common params", common_options);
print_section("sampling params", sampling_options);
print_section("speculative params", spec_options);
print_section("example-specific params", specific_options);
}
static void common_params_print_completion(common_params_context & ctx_arg) {
@@ -1070,7 +1196,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
* - if both {LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_*,} are set, we will prioritize the LLAMA_EXAMPLE_* matching current example
*/
auto add_opt = [&](common_arg arg) {
if ((arg.in_example(ex) || arg.in_example(LLAMA_EXAMPLE_COMMON)) && !arg.is_exclude(ex)) {
// download only exposes the handful of args explicitly tagged for it
const bool inherit_common = ex != LLAMA_EXAMPLE_DOWNLOAD;
if ((arg.in_example(ex) || (inherit_common && arg.in_example(LLAMA_EXAMPLE_COMMON))) && !arg.is_exclude(ex)) {
ctx_arg.options.push_back(std::move(arg));
}
};
@@ -1081,7 +1209,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params) {
params.usage = true;
}
));
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}));
add_opt(common_arg(
{"--version"},
"show version and build info",
@@ -2203,7 +2331,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, bool value) {
params.no_mmproj = !value;
}
).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_AUTO"));
).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MMPROJ_AUTO"));
add_opt(common_arg(
{"--mmproj-offload"},
{"--no-mmproj-offload"},
@@ -2602,14 +2730,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.model.path = value;
}
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL"));
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MODEL"));
add_opt(common_arg(
{"-mu", "--model-url"}, "MODEL_URL",
"model download url (default: unused)",
[](common_params & params, const std::string & value) {
params.model.url = value;
}
).set_env("LLAMA_ARG_MODEL_URL"));
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MODEL_URL"));
add_opt(common_arg(
{ "-dr", "--docker-repo" }, "[<repo>/]<model>[:quant]",
"Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.\n"
@@ -2618,7 +2746,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.model.docker_repo = value;
}
).set_env("LLAMA_ARG_DOCKER_REPO"));
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_DOCKER_REPO"));
add_opt(common_arg(
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
@@ -2628,14 +2756,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.model.hf_repo = value;
}
).set_env("LLAMA_ARG_HF_REPO"));
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_HF_REPO"));
add_opt(common_arg(
{"-hff", "--hf-file"}, "FILE",
"Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
[](common_params & params, const std::string & value) {
params.model.hf_file = value;
}
).set_env("LLAMA_ARG_HF_FILE"));
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_HF_FILE"));
add_opt(common_arg(
{"-hfv", "-hfrv", "--hf-repo-v"}, "<user>/<model>[:quant]",
"Hugging Face model repository for the vocoder model (default: unused)",
@@ -2656,7 +2784,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.hf_token = value;
}
).set_env("HF_TOKEN"));
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("HF_TOKEN"));
add_opt(common_arg(
{"--mtp"},
"also download the multi-token prediction (MTP) head, if available (default: unused)",
[](common_params & params) {
params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_DRAFT_MTP);
}
).set_examples({LLAMA_EXAMPLE_DOWNLOAD}));
add_opt(common_arg(
{"--context-file"}, "FNAME",
"file to load context from (use comma-separated values to specify multiple files)",
@@ -3613,6 +3748,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"draft model for speculative decoding (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.draft.mparams.path = value;
params.speculative.draft.mparams.hf_file = value; // will be used if --spec-draft-hf is set
}
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
add_opt(common_arg(
+17 -5
View File
@@ -1,12 +1,14 @@
#pragma once
#include "common.h"
#include "download.h"
#include <set>
#include <map>
#include <string>
#include <vector>
#include <cstring>
#include <memory>
// pseudo-env variable to identify preset-only arguments
#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP"
@@ -129,11 +131,21 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
// see: https://github.com/ggml-org/llama.cpp/issues/18163
void common_params_add_preset_options(std::vector<common_arg> & args);
// populate model paths (main model, mmproj, etc) from -hf if necessary
// return true if the model is ready to use
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
bool common_params_handle_models(common_params & params, llama_example curr_ex);
struct common_models_handler {
common_download_hf_plan plan;
common_download_hf_plan plan_spec;
common_download_hf_plan plan_voc;
common_download_opts opts;
};
// initialize downloading opts and hf_plan if needed, but does not download anything yet
common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex);
// check if the model is a preset repo (i.e. has a preset file)
bool common_models_handler_is_preset_repo(const common_models_handler & handler);
// download and update params with the downloaded model path
void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback = nullptr);
// initialize argument parser context - used by test-arg-parser and preset
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
+5 -4
View File
@@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
arguments.name_suffix) +
arguments.value_prefix +
(schema_info.resolves_to_string(param_schema) ?
p.tool_arg_string_value(until_suffix) :
p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) +
p.tool_arg_close(p.literal(arguments.value_suffix)));
p.ac(p.tool_arg_string_value(until_suffix) +
p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) :
(p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
p.tool_arg_close(p.literal(arguments.value_suffix)))));
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
if (is_required) {
+107 -53
View File
@@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
return text;
}
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
if (delims.empty() || prompt.empty()) {
return {};
common_chat_role common_chat_role_from_string(const std::string & role) {
if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; }
if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; }
if (role == "user") { return COMMON_CHAT_ROLE_USER; }
if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; }
return COMMON_CHAT_ROLE_UNKNOWN;
}
const char * common_chat_role_to_string(common_chat_role role) {
switch (role) {
case COMMON_CHAT_ROLE_SYSTEM: return "system";
case COMMON_CHAT_ROLE_ASSISTANT: return "assistant";
case COMMON_CHAT_ROLE_USER: return "user";
case COMMON_CHAT_ROLE_TOOL: return "tool";
case COMMON_CHAT_ROLE_UNKNOWN: return "";
}
return "";
}
json common_chat_msg_delimiters::to_json() const {
json result = json::array();
for (const auto & d : delimiters) {
result.push_back({
{ "role", common_chat_role_to_string(d.role) },
{ "delimiter", d.delimiter },
});
}
return result;
}
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
common_chat_msg_delimiters result;
if (!delimiters.is_array()) {
return result;
}
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
std::vector<std::string> all_delims;
std::vector<common_peg_parser> tagged_messages;
all_delims.reserve(delims.size());
tagged_messages.reserve(delims.size());
for (const auto & d : delims) {
all_delims.push_back(d.delimiter);
result.delimiters.reserve(delimiters.size());
for (const auto & d : delimiters) {
if (!d.is_object()) {
continue;
}
auto any_delim = p.until_one_of(all_delims);
for (const auto & d : delims) {
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
}
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
});
common_peg_parse_context ctx(prompt);
const auto result = parser.parse(ctx);
if (!result.success()) {
return {};
result.delimiters.push_back({
common_chat_role_from_string(d.value("role", std::string())),
d.value("delimiter", std::string()),
});
}
std::vector<common_chat_msg_span> spans;
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
if (!node.tag.empty()) {
spans.push_back({ node.tag, node.start, node.end - node.start });
return result;
}
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
for (auto & d : delimiters) {
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
}
}
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
std::vector<std::pair<common_chat_role, size_t>> matches;
auto skip = skips.begin();
for (size_t i = 0; i < tokens.size();) {
if (skip != skips.end() && i == skip->first) {
i += skip->second;
++skip;
continue;
}
});
for (const auto & d : delimiters) {
if (i + d.tokens.size() > tokens.size()) {
continue;
}
if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
matches.emplace_back(d.role, i);
break;
}
}
i++;
}
matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());
common_chat_msg_spans spans;
for (size_t i = 0; i + 1 < matches.size(); i++) {
const auto & curr = matches[i];
const auto & next = matches[i + 1];
spans.add(curr.first, curr.second, next.second - curr.second);
}
return spans;
}
@@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
data.prompt = prompt;
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
data.message_spans = common_chat_split_by_role(prompt, {
{ "assistant", "<|start|>assistant" },
{ "user", "<|start|>user" },
{ "system", "<|start|>developer" },
{ "system", "<|start|>system" },
{ "tool", "<|start|>functions" },
});
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
};
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;
@@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
data.prompt += data.generation_prompt;
}
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "user", "<|turn>user\n" },
{ "assistant", "<|turn>model\n" },
});
data.message_delimiters = {
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
};
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
data.supports_thinking = true;
@@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
RESULT_START, RESULT_END,
};
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
// Declare per-role message delimiters. Tool results are rendered with the
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
data.message_spans = common_chat_split_by_role(data.prompt, {
{ "assistant", GEN_PREFIX },
{ "user", TURN_START + USER },
{ "tool", TURN_START + SYSTEM + RESULT_START },
{ "system", TURN_START + SYSTEM },
});
data.message_delimiters = {
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
};
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
@@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
std::vector<common_chat_msg_delimiter> delimiters;
common_chat_msg_delimiters delimiters;
if (!autoparser.assistant_start.empty()) {
delimiters.push_back({ "assistant", autoparser.assistant_start });
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
}
if (!autoparser.user_start.empty()) {
delimiters.push_back({ "user", autoparser.user_start });
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
}
if (!delimiters.empty()) {
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
}
auto_params.message_delimiters = std::move(delimiters);
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
@@ -2708,5 +2758,9 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates) {
GGML_ASSERT(chat_templates != nullptr);
GGML_ASSERT(chat_templates->template_default != nullptr);
if (chat_templates->template_tool_use != nullptr) {
// take the more expressive template when available
return chat_templates->template_tool_use->caps.to_map();
}
return chat_templates->template_default->caps.to_map();
}
+65 -6
View File
@@ -143,15 +143,75 @@ struct common_chat_msg_diff {
}
};
enum common_chat_role {
COMMON_CHAT_ROLE_UNKNOWN,
COMMON_CHAT_ROLE_SYSTEM,
COMMON_CHAT_ROLE_ASSISTANT,
COMMON_CHAT_ROLE_USER,
COMMON_CHAT_ROLE_TOOL
};
common_chat_role common_chat_role_from_string(const std::string & role);
const char * common_chat_role_to_string(common_chat_role role);
struct common_chat_msg_span {
std::string role;
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::size_t pos = 0;
std::size_t len = 0;
bool valid() const {
return role != COMMON_CHAT_ROLE_UNKNOWN;
}
};
struct common_chat_msg_spans {
std::vector<common_chat_msg_span> spans;
void add(common_chat_role role, size_t pos, size_t len) {
spans.push_back({ role, pos, len });
}
bool is_user_start(int32_t pos) const {
for (auto it = spans.begin(); it != spans.end(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
return true;
}
}
return false;
}
int32_t last_user_message_pos() const {
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
if (it->role == COMMON_CHAT_ROLE_USER) {
return (int32_t) it->pos;
}
}
return -1;
}
};
struct common_chat_msg_delimiter {
std::string role;
std::string delimiter;
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
std::string delimiter;
llama_tokens tokens = {};
};
struct common_chat_msg_delimiters {
std::vector<common_chat_msg_delimiter> delimiters;
common_chat_msg_delimiters() = default;
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
void add(common_chat_role role, const std::string & delimiter) {
delimiters.push_back({ role, delimiter });
}
void tokenize(const llama_vocab * vocab);
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
nlohmann::ordered_json to_json() const;
};
struct common_chat_tool {
@@ -219,7 +279,7 @@ struct common_chat_params {
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
std::string parser;
std::vector<common_chat_msg_span> message_spans;
common_chat_msg_delimiters message_delimiters;
};
// per-message parsing syntax
@@ -325,5 +385,4 @@ struct common_chat_prompt_preset {
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
+13 -9
View File
@@ -96,6 +96,7 @@ enum llama_example {
LLAMA_EXAMPLE_FIT_PARAMS,
LLAMA_EXAMPLE_RESULTS,
LLAMA_EXAMPLE_EXPORT_GRAPH_OPS,
LLAMA_EXAMPLE_DOWNLOAD,
LLAMA_EXAMPLE_COUNT,
};
@@ -290,13 +291,13 @@ struct common_params_sampling {
};
struct common_params_model {
std::string path = ""; // model local path // NOLINT
std::string url = ""; // model url to download // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string docker_repo = ""; // Docker repo // NOLINT
std::string path = ""; // model local path
std::string url = ""; // model url to download
std::string hf_repo = ""; // HF repo
std::string hf_file = ""; // HF file
std::string docker_repo = ""; // Docker repo
std::string get_name() {
std::string get_name() const {
if (!hf_repo.empty()) {
return hf_repo;
}
@@ -305,6 +306,10 @@ struct common_params_model {
}
return path;
}
bool empty() const {
return get_name().empty();
}
};
// draft-model-based speculative decoding parameters
@@ -367,7 +372,7 @@ struct common_params_speculative {
common_params_speculative_ngram_cache ngram_cache;
bool has_dft() const {
return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
return !draft.mparams.empty();
}
uint32_t need_n_rs_seq() const {
@@ -519,7 +524,6 @@ struct common_params {
int32_t control_vector_layer_start = -1; // layer range for control vector
int32_t control_vector_layer_end = -1; // layer range for control vector
bool offline = false;
bool skip_download = false; // skip model file downloading
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
@@ -609,7 +613,7 @@ struct common_params {
bool cache_prompt = true; // whether to enable prompt caching
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
std::string hostname = "127.0.0.1";
+22 -116
View File
@@ -292,10 +292,6 @@ static int common_download_file_single_online(const std::string & url,
const bool file_exists = std::filesystem::exists(path);
if (!file_exists && opts.skip_download) {
return -2; // file is missing and download is disabled
}
if (file_exists && skip_etag) {
LOG_DBG("%s: using cached file: %s\n", __func__, path.c_str());
return 304; // 304 Not Modified - fake cached response
@@ -362,9 +358,6 @@ static int common_download_file_single_online(const std::string & url,
return 304; // 304 Not Modified - fake cached response
}
// pass this point, the file exists but is different from the server version, so we need to redownload it
if (opts.skip_download) {
return -2; // special code to indicate that the download was skipped due to etag mismatch
}
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return -1;
@@ -691,19 +684,8 @@ static void list_available_gguf_files(const hf_cache::hf_files & files) {
}
}
struct hf_plan {
hf_cache::hf_file primary;
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
hf_cache::hf_file mtp;
hf_cache::hf_file preset; // if set, only this file is downloaded
};
static hf_plan get_hf_plan(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj,
bool download_mtp) {
hf_plan plan;
common_download_hf_plan common_download_get_hf_plan(const common_params_model & model, const common_download_opts & opts) {
common_download_hf_plan plan;
hf_cache::hf_files all;
auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);
@@ -752,125 +734,49 @@ static hf_plan get_hf_plan(const common_params_model & model,
plan.primary = primary;
plan.model_files = get_split_files(all, primary);
if (download_mmproj) {
if (opts.download_mmproj) {
plan.mmproj = find_best_mmproj(all, primary.path);
}
if (download_mtp) {
if (opts.download_mtp) {
plan.mtp = find_best_mtp(all, primary.path);
}
return plan;
}
struct download_task {
std::string url;
std::string path;
};
static std::vector<download_task> get_url_tasks(const common_params_model & model) {
auto split = get_gguf_split_info(model.url);
if (split.count <= 1) {
return {{model.url, model.path}};
}
auto filename = split.prefix;
if (auto pos = split.prefix.rfind('/'); pos != std::string::npos) {
filename = split.prefix.substr(pos + 1);
}
auto parent_path = std::filesystem::path(model.path).parent_path();
auto prefix_path = (parent_path / filename).string();
std::vector<download_task> tasks;
for (int i = 1; i <= split.count; i++) {
auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count);
tasks.push_back({split.prefix + suffix, prefix_path + suffix});
}
return tasks;
}
common_download_model_result common_download_model(const common_params_model & model,
const common_download_opts & opts) {
common_download_model_result result;
std::vector<download_task> tasks;
hf_plan hf;
bool download_mmproj = opts.download_mmproj;
bool download_mtp = opts.download_mtp;
bool is_hf = !model.hf_repo.empty();
if (is_hf) {
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
if (!hf.preset.path.empty()) {
// if preset.ini exists, only download that file alone
tasks.push_back({hf.preset.url, hf.preset.local_path});
} else {
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
if (!hf.mmproj.path.empty()) {
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
}
if (!hf.mtp.path.empty()) {
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
}
}
} else if (!model.url.empty()) {
tasks = get_url_tasks(model);
} else {
result.model_path = model.path;
return result;
}
if (tasks.empty()) {
return result;
}
void common_download_run_tasks(const std::vector<common_download_task> & tasks) {
std::vector<std::future<int>> futures;
for (const auto & task : tasks) {
futures.push_back(std::async(std::launch::async,
[&task, &opts, is_hf]() {
return common_download_file_single(task.url, task.path, opts, is_hf);
[&task]() {
return common_download_file_single(task.url, task.local_path, task.opts, task.is_hf);
}
));
}
for (auto & f : futures) {
int status = f.get();
if (status == -2 && opts.skip_download) {
throw common_skip_download_exception();
}
for (size_t i = 0; i < futures.size(); ++i) {
std::string url = tasks[i].url;
int status = futures[i].get();
bool is_ok = is_http_status_ok(status);
if (!is_ok) {
return {};
throw std::runtime_error(string_format("Download '%s' failed with status code: %d", url.c_str(), status));
}
}
}
if (is_hf) {
if (!hf.preset.path.empty()) {
// if preset.ini is used, do not set other paths
result.preset_path = hf_cache::finalize_file(hf.preset);
} else {
for (const auto & f : hf.model_files) {
hf_cache::finalize_file(f);
}
result.model_path = hf.primary.final_path;
std::vector<std::string> common_download_get_all_parts(const std::string & url) {
auto split = get_gguf_split_info(url);
if (!hf.mmproj.path.empty()) {
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
}
if (!hf.mtp.path.empty()) {
result.mtp_path = hf_cache::finalize_file(hf.mtp);
}
}
} else {
result.model_path = model.path;
if (split.count <= 1) {
return {url};
}
return result;
std::vector<std::string> parts;
for (int i = 1; i <= split.count; i++) {
auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count);
parts.push_back(split.prefix + suffix);
}
return parts;
}
//
+28 -42
View File
@@ -1,7 +1,10 @@
#pragma once
#include "hf-cache.h"
#include <string>
#include <vector>
#include <functional>
struct common_params_model;
@@ -47,66 +50,40 @@ struct common_cached_model_info {
}
};
// Options for common_download_model and common_download_file_single
// Options for common_download_file_single
struct common_download_opts {
std::string bearer_token;
common_header_list headers;
bool offline = false;
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
bool download_mmproj = false;
bool download_mtp = false;
common_download_callback * callback = nullptr;
};
// Result of common_download_model
struct common_download_model_result {
std::string model_path;
std::string mmproj_path;
std::string mtp_path;
std::string preset_path;
struct common_download_task {
common_download_opts opts;
std::string url;
std::string local_path;
std::function<void()> on_done;
bool is_hf = false;
common_download_task() = default;
common_download_task(hf_cache::hf_file f,
const common_download_opts & opts,
std::function<void()> on_done = nullptr)
: opts(opts), url(f.url), local_path(f.local_path), on_done(on_done), is_hf(true) {}
};
// throw if the file is missing or invalid (e.g. ETag check failed)
struct common_skip_download_exception : public std::runtime_error {
common_skip_download_exception() : std::runtime_error("skip download") {}
};
void common_download_run_tasks(const std::vector<common_download_task> & tasks);
// Download model from HuggingFace repo or URL
//
// input (via model struct):
// - model.hf_repo: HF repo with optional tag, see common_download_split_repo_tag
// - model.hf_file: specific file in the repo (requires hf_repo)
// - model.url: simple download (used if hf_repo is empty)
// - model.path: local file path
//
// tag matching (for HF repos without model.hf_file):
// - if tag is specified, searches for GGUF matching that quantization
// - if no tag, searches for Q4_K_M, then Q4_0, then first available GGUF
//
// split GGUF: multi-part files like "model-00001-of-00003.gguf" are automatically
// detected and all parts are downloaded
//
// caching:
// - HF repos: uses HuggingFace cache
// - URLs: uses ETag-based caching
//
// when opts.offline=true, no network requests are made
// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
// then with the closest quantization bits
// when download_mtp=true, applies the same sibling search for an MTP-head GGUF
//
// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure)
common_download_model_result common_download_model(
const common_params_model & model,
const common_download_opts & opts = {}
);
// if url is a multi-part GGUF file, returns all parts, otherwise returns the single file
std::vector<std::string> common_download_get_all_parts(const std::string & url);
// returns list of cached models
std::vector<common_cached_model_info> common_list_cached_models();
// download single file from url to local path
// returns status code or -1 on error
// returns -2 if the download was skipped due to ETag mismatch (file outdated, skip_download=true)
// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
int common_download_file_single(const std::string & url,
const std::string & path,
@@ -123,3 +100,12 @@ std::string common_docker_resolve_model(const std::string & docker);
// - if tag is present, removes only files matching that tag (and orphaned blobs)
// returns true if anything was removed
bool common_download_remove(const std::string & hf_repo_with_tag);
struct common_download_hf_plan {
hf_cache::hf_file primary;
hf_cache::hf_files model_files;
hf_cache::hf_file mmproj;
hf_cache::hf_file mtp;
hf_cache::hf_file preset; // if set, only this file is downloaded
};
common_download_hf_plan common_download_get_hf_plan(const common_params_model & model, const common_download_opts & opts);
-324
View File
@@ -1,324 +0,0 @@
#include "json-partial.h"
#include "log.h"
#include <nlohmann/json.hpp>
#include <string>
#include <regex>
using json = nlohmann::ordered_json;
enum common_json_stack_element_type {
COMMON_JSON_STACK_ELEMENT_OBJECT,
COMMON_JSON_STACK_ELEMENT_KEY,
COMMON_JSON_STACK_ELEMENT_ARRAY,
};
struct common_json_stack_element {
common_json_stack_element_type type;
std::string key;
};
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out)
{
std::string::const_iterator it = input.begin();
const auto end = input.end();
return common_json_parse(it, end, healing_marker, out);
}
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out)
{
// // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> {
std::size_t position;
bool found_error;
std::string last_token;
std::string exception_message;
std::vector<common_json_stack_element> stack;
json_error_locator() : position(0), found_error(false) {}
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
this->position = position - 1;
this->found_error = true;
this->last_token = last_token;
this->exception_message = ex.what();
return false;
}
void close_value() {
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
stack.pop_back();
}
}
bool null() override { // NOLINT
close_value();
return true;
}
bool boolean(bool) override { // NOLINT
close_value();
return true;
}
bool number_integer(number_integer_t) override { // NOLINT
close_value();
return true;
}
bool number_unsigned(number_unsigned_t) override { // NOLINT
close_value();
return true;
}
bool number_float(number_float_t, const string_t &) override { // NOLINT
close_value();
return true;
}
bool string(string_t &) override { // NOLINT
close_value();
return true;
}
bool binary(binary_t &) override { // NOLINT
close_value();
return true;
}
bool start_object(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
return true;
}
bool end_object() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
stack.pop_back();
close_value();
return true;
}
bool key(string_t & key) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
return true;
}
bool start_array(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
return true;
}
bool end_array() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
stack.pop_back();
close_value();
return true;
}
};
json_error_locator err_loc;
auto start = it;
json::sax_parse(it, end, &err_loc);
if (err_loc.found_error) {
it = start;
auto temptative_end = it + err_loc.position;
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
auto input = std::string(it, temptative_end);
try {
out.json = json::parse(input);
// out.json = json::parse(it, temptative_end);
it = temptative_end;
return true;
} catch (const std::exception & ex) {
// No, needs healing.
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
}
auto can_parse = [](const std::string & str) {
try {
auto _ = json::parse(str); // NOLINT
return true;
} catch (const std::exception &) {
return false;
}
};
if (!healing_marker.empty() && !err_loc.stack.empty()) {
std::string str(it, temptative_end);
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
if (last_non_sp_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
auto last_non_sp_char = str[last_non_sp_pos];
// Used to detect stops on a number, which may not be complete.
auto was_maybe_number = [&]() {
if (!str.empty() && std::isspace(str.back())) {
return false;
}
return std::isdigit(last_non_sp_char) ||
last_non_sp_char == '.' ||
last_non_sp_char == 'e' ||
last_non_sp_char == 'E' ||
last_non_sp_char == '-';
};
std::string closing;
for (size_t i = err_loc.stack.size(); i > 0; i--) {
auto & el = err_loc.stack[i - 1];
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
closing += "}";
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
closing += "]";
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
throw std::runtime_error("Unexpected stack element type");
}
}
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
auto is_high_surrogate = [&](const std::string & s) {
// Check if a partial of a high surrogate (U+D800-U+DBFF)
return s.length() >= 4 &&
s[0] == '\\' && s[1] == 'u' &&
std::tolower(s[2]) == 'd' &&
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
};
// Initialize the unicode marker to a low surrogate to handle the edge case
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
// backslash (\)
std::string unicode_marker_padding = "udc00";
std::smatch last_unicode_seq;
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
std::smatch second_last_seq;
std::string prelude = str.substr(0, last_unicode_seq.position());
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
if (is_high_surrogate(last_unicode_seq.str())) {
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
unicode_marker_padding += "\\udc00";
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
if (is_high_surrogate(second_last_seq.str())) {
// If this follows a high surrogate, pad it to be a low surrogate
if (last_unicode_seq.length() == 2) {
unicode_marker_padding = "dc00";
} else if (last_unicode_seq.length() == 3) {
unicode_marker_padding = "c00";
} else {
// The original unicode_marker_padding is already padded with 0s
}
}
}
}
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
// We're inside an object value
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
// Was about to create an object value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + ": 1" + closing)) {
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
// Was about to create an object
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an object value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an object value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an object value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else {
// find last :
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
// Cutting back to opening : for object value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
// Was about to create an array value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an array value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an array value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
// Was inside an array value string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
// Had just finished a value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
} else {
auto last_pos = str.find_last_of("[,");
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
}
// Cutting back to last [ or , for array value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\": 1" + closing)) {
// Was inside an object key string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
// Was inside an object key string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
// Was inside an object key string after a partial unicode escape
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
} else {
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "Cutting back to last : for object key+value\n");
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
out.json = json::parse(str);
it = temptative_end;
return true;
}
// handle unclosed top-level primitive
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
std::string str(it, temptative_end);
const auto & magic_seed = out.healing_marker.marker = healing_marker;
if (can_parse(str + "\"")) {
// Was inside an string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
// Was inside an string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
} else {
// TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
// fprintf(stderr, "Closing: TODO\n");
return false;
}
out.json = json::parse(str);
it = temptative_end;
return true;
}
return false;
}
out.json = json::parse(it, end);
it = end;
return true;
}
-39
View File
@@ -1,39 +0,0 @@
#pragma once
// TODO: use json_fwd.hpp when possible
#include <nlohmann/json.hpp>
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
struct common_healing_marker {
// Raw marker.
std::string marker;
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
std::string json_dump_marker;
};
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
struct common_json {
nlohmann::ordered_json json;
common_healing_marker healing_marker;
};
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
//
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
//
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out);
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out);
+102 -29
View File
@@ -921,6 +921,10 @@ struct parser_executor {
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
return arena.parse(p.child, ctx, start_pos);
}
common_peg_parse_result operator()(const common_peg_ac_parser & p) {
return arena.parse(p.child, ctx, start_pos);
}
};
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
@@ -989,7 +993,8 @@ void common_peg_arena::resolve_refs() {
std::is_same_v<T, common_peg_not_parser> ||
std::is_same_v<T, common_peg_tag_parser> ||
std::is_same_v<T, common_peg_atomic_parser> ||
std::is_same_v<T, common_peg_gbnf_parser>) {
std::is_same_v<T, common_peg_gbnf_parser> ||
std::is_same_v<T, common_peg_ac_parser>) {
p.child = resolve_ref(p.child);
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
p.child = resolve_ref(p.child);
@@ -1070,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
return "Atomic(" + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")";
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
return "Any";
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
@@ -1479,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
});
}
common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector<std::string> & delimiters) {
if (delimiters.empty()) {
throw std::runtime_error("ac parser requires at least one delimiter");
}
return add(common_peg_ac_parser{p, delimiters});
}
static std::string gbnf_escape_char_class(uint32_t c) {
if (c == '-' || c == ']' || c == '[' || c == '\\') {
return "\\" + std::string(1, (char) c);
@@ -1529,14 +1543,22 @@ static std::string gbnf_escape_char_class(uint32_t c) {
return std::string(buf);
}
// GBNF grammar matching strings that contain no string in `strings` as a
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
// the start state rule name.
//
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings) {
static std::string gbnf_char_class(const std::vector<uint32_t> & chars, bool negate) {
std::string s = negate ? "[^" : "[";
for (uint32_t ch : chars) {
s += gbnf_escape_char_class(ch);
}
return s + "]";
}
static std::string gbnf_ac_grammar(
const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings,
const std::function<std::string(const std::vector<uint32_t> &,
const std::map<size_t, std::vector<uint32_t>> &,
const std::vector<uint32_t> &,
const std::function<std::string(size_t)> &)> & build_rule) {
aho_corasick ac(strings);
auto state_name = [&](size_t s) -> std::string {
@@ -1548,42 +1570,30 @@ static std::string gbnf_excluding_grammar(const common_grammar_builder & builder
return prefix + "-" + num;
};
auto char_class = [](const std::vector<uint32_t> & chars, bool negate) {
std::string s = negate ? "[^" : "[";
for (uint32_t ch : chars) {
s += gbnf_escape_char_class(ch);
}
return s + "]";
};
for (size_t q = 0; q < ac.num_states(); q++) {
if (ac.is_terminal(q)) {
continue; // match states are dropped
continue; // match states
}
std::map<size_t, std::vector<uint32_t>> buckets;
std::vector<uint32_t> excluded;
std::vector<uint32_t> completing; // chars that complete a delimiter
std::vector<uint32_t> specific; // chars with an explicit transition
for (uint32_t c : ac.alphabet) {
size_t d = ac.next(q, c);
if (ac.is_terminal(d)) {
excluded.push_back(c); // completes a forbidden string -> omit
completing.push_back(c);
specific.push_back(c);
} else if (d != 0) {
buckets[d].push_back(c); // specific non-root destination
excluded.push_back(c);
specific.push_back(c);
}
}
std::string rhs = "|"; // every state is accepting
for (const auto & [d, chars] : buckets) {
rhs += " " + char_class(chars, false) + " " + state_name(d) + " |";
}
rhs += " " + char_class(excluded, true) + " " + state_name(0);
builder.add_rule(state_name(q), rhs);
builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name));
}
// An empty delimiter makes the start state terminal. Emit an entry rule
// that matches nothing so the returned reference stays valid.
// that matches the empty string so the returned reference stays valid.
if (ac.is_terminal(0)) {
builder.add_rule(prefix, "|");
}
@@ -1591,6 +1601,54 @@ static std::string gbnf_excluding_grammar(const common_grammar_builder & builder
return state_name(0);
}
// GBNF grammar matching strings that contain no string in `strings` as a
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
// the start state rule name.
//
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings) {
return gbnf_ac_grammar(builder, prefix, strings,
[](const std::vector<uint32_t> & /*completing*/,
const std::map<size_t, std::vector<uint32_t>> & buckets,
const std::vector<uint32_t> & specific,
const std::function<std::string(size_t)> & state_name) {
// every state is accepting and completing chars get no
// alternative, so a forbidden string can never be matched
std::string rhs = "|";
for (const auto & [d, chars] : buckets) {
rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |";
}
rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0);
return rhs;
});
}
// GBNF grammar matching everything up to and including the first occurrence of
// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns
// the start state rule name.
static std::string gbnf_including_grammar(const common_grammar_builder & builder,
const std::string & prefix,
const std::vector<std::string> & strings) {
return gbnf_ac_grammar(builder, prefix, strings,
[](const std::vector<uint32_t> & completing,
const std::map<size_t, std::vector<uint32_t>> & buckets,
const std::vector<uint32_t> & specific,
const std::function<std::string(size_t)> & state_name) {
std::vector<std::string> alts;
if (!completing.empty()) {
alts.push_back(gbnf_char_class(completing, false)); // terminate on match
}
for (const auto & [d, chars] : buckets) {
alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d));
}
// every other character keeps scanning from the start state
alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0));
return string_join(alts, " | ");
});
}
static std::set<std::string> collect_reachable_rules(
const common_peg_arena & arena,
const common_peg_parser_id & rule
@@ -1628,6 +1686,7 @@ static std::set<std::string> collect_reachable_rules(
std::is_same_v<T, common_peg_tag_parser> ||
std::is_same_v<T, common_peg_atomic_parser> ||
std::is_same_v<T, common_peg_gbnf_parser> ||
std::is_same_v<T, common_peg_ac_parser> ||
std::is_same_v<T, common_peg_schema_parser>) {
visit(p.child);
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
@@ -1822,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
return to_gbnf(p.child);
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return p.grammar;
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters);
} else {
static_assert(is_always_false_v<T>);
}
@@ -1958,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
};
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}};
}
}, variant);
}
@@ -2130,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
};
}
if (type == "ac") {
if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) {
throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array");
}
return common_peg_ac_parser{
j["child"].get<common_peg_parser_id>(),
j["delimiters"].get<std::vector<std::string>>(),
};
}
throw std::runtime_error("Unknown parser type: " + type);
}
+14 -1
View File
@@ -275,6 +275,11 @@ struct common_peg_gbnf_parser {
std::string grammar;
};
struct common_peg_ac_parser {
common_peg_parser_id child;
std::vector<std::string> delimiters;
};
// Variant holding all parser types
using common_peg_parser_variant = std::variant<
common_peg_epsilon_parser,
@@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant<
common_peg_ref_parser,
common_peg_atomic_parser,
common_peg_tag_parser,
common_peg_gbnf_parser
common_peg_gbnf_parser,
common_peg_ac_parser
>;
class common_peg_arena {
@@ -514,6 +520,13 @@ class common_peg_parser_builder {
// the child's grammar. Parsing delegates entirely to the child.
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
// Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick
// automaton of `delimiters`, matching everything up to and including the
// first delimiter. Parsing delegates entirely to the child, which is
// responsible for consuming the delimiter (e.g. until(D) + literal(D)).
common_peg_parser ac(const common_peg_parser & p, const std::vector<std::string> & delimiters);
common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector<std::string>{delimiter}); }
void set_root(const common_peg_parser & p);
common_peg_arena build();
+7
View File
@@ -46,6 +46,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"DbrxForCausalLM": "dbrx",
"DeciLMForCausalLM": "deci",
"DeepseekForCausalLM": "deepseek",
"DeepseekOCRForCausalLM": "deepseek",
"DeepseekV2ForCausalLM": "deepseek",
"DeepseekV3ForCausalLM": "deepseek",
"DeepseekV32ForCausalLM": "deepseek",
@@ -96,6 +97,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"GraniteMoeHybridForCausalLM": "granite",
"GraniteMoeSharedForCausalLM": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"Grok1ForCausalLM": "grok",
"GrokForCausalLM": "grok",
"GroveMoeForCausalLM": "grovemoe",
@@ -123,6 +125,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"LLaDAModelLM": "llada",
"LLaMAForCausalLM": "llama",
"Lfm25AudioTokenizer": "lfm2",
"Lfm2BidirectionalModel": "lfm2",
"Lfm2ForCausalLM": "lfm2",
"Lfm2Model": "lfm2",
"Lfm2MoeForCausalLM": "lfm2",
@@ -133,6 +136,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"LlamaModel": "llama",
"Eagle3DraftModel": "llama",
"Eagle3Speculator": "llama",
"Eagle3LlamaForCausalLM": "llama",
"LlamaForCausalLMEagle3": "llama",
"LlavaForConditionalGeneration": "llama",
"LlavaStableLMEpochForCausalLM": "stablelm",
@@ -231,6 +235,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
"UMT5ForConditionalGeneration": "t5",
"UMT5Model": "t5",
"UltravoxModel": "ultravox",
"UnlimitedOCRForCausalLM": "deepseek",
"VLlama3ForCausalLM": "llama",
"VoxtralForConditionalGeneration": "llama",
"WavTokenizerDec": "wavtokenizer",
@@ -261,6 +266,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
"GlmasrModel": "ultravox",
"Granite4VisionForConditionalGeneration": "granite",
"GraniteSpeechForConditionalGeneration": "granite",
"GraniteSpeechPlusForConditionalGeneration": "granite",
"HunYuanVLForConditionalGeneration": "hunyuan",
"Idefics3ForConditionalGeneration": "smolvlm",
"InternVisionModel": "internvl",
@@ -296,6 +302,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
"StepVLForConditionalGeneration": "step3",
"Step3p7ForConditionalGeneration": "step3",
"UltravoxModel": "ultravox",
"UnlimitedOCRForCausalLM": "deepseek",
"VoxtralForConditionalGeneration": "ultravox",
"YoutuVLForConditionalGeneration": "youtuvl",
}
+10 -2
View File
@@ -14,7 +14,7 @@ from .base import MmprojModel, ModelBase, TextModel, gguf, logger
from .qwen import QwenModel
@ModelBase.register("DeepseekOCRForCausalLM")
@ModelBase.register("DeepseekOCRForCausalLM", "UnlimitedOCRForCausalLM")
class DeepseekOCRVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -205,6 +205,8 @@ class DeepseekModel(TextModel):
@ModelBase.register(
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekOCRForCausalLM",
"UnlimitedOCRForCausalLM",
"KimiVLForConditionalGeneration",
"KimiK25ForConditionalGeneration",
"YoutuForCausalLM",
@@ -224,7 +226,7 @@ class DeepseekV2Model(TextModel):
self.origin_hf_arch = hparams.get('architectures', [None])[0]
# special handling for Deepseek OCR
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM"):
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM", "UnlimitedOCRForCausalLM"):
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
self.gguf_writer.add_architecture()
@@ -350,6 +352,12 @@ class DeepseekV2Model(TextModel):
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
# Unlimited-OCR sliding window; written for metadata, the decoder ignores it (full MHA)
if is_ocr:
sliding_window = hparams.get("sliding_window_size") or hparams.get("sliding_window")
if sliding_window:
self.gguf_writer.add_sliding_window(sliding_window)
if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
+28
View File
@@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
"""Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation"""
has_vision_encoder = False
has_audio_encoder = True
def set_gguf_parameters(self):
assert self.hparams_audio is not None
super().set_gguf_parameters()
# Add feature_layer if present in encoder config
if feature_layers := self.hparams_audio.get("cat_hidden_layers"):
self.gguf_writer.add_audio_feature_layers(feature_layers)
logger.info(f"gguf: audio feature_layers = {feature_layers}")
# Validate projector dimension matches concatenated encoder output
hidden_dim = self.hparams_audio["hidden_dim"]
expected_dim = hidden_dim * (len(feature_layers) + 1)
projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]
if projector_dim != expected_dim:
raise ValueError(
f"Projector encoder_hidden_size ({projector_dim}) does not match "
f"expected concatenated dimension ({expected_dim}). "
f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
)
@ModelBase.register("Granite4VisionForConditionalGeneration")
class Granite4VisionMmprojModel(MmprojModel):
has_vision_encoder = True
+10 -3
View File
@@ -64,11 +64,17 @@ class LFM2Model(TextModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Lfm2Model")
@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel")
class LFM2ColBertModel(LFM2Model):
model_arch = gguf.MODEL_ARCH.LFM2
dense_tensor_name = "dense_2"
def set_gguf_parameters(self):
super().set_gguf_parameters()
if self.hf_arch == "Lfm2BidirectionalModel":
self.gguf_writer.add_causal_attention(False)
self._try_set_pooling_type()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not name.startswith(self.dense_tensor_name):
name = "model." + name
@@ -76,10 +82,11 @@ class LFM2ColBertModel(LFM2Model):
yield from super().modify_tensors(data_torch, name, bid)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# dense tensor is stored in a separate safetensors file
# optional dense tensor is stored in a separate safetensors file
from safetensors.torch import load_file
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
assert tensors_file.is_file()
if not tensors_file.is_file():
return
tensor = load_file(tensors_file)["linear.weight"]
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
yield f"{self.dense_tensor_name}.weight", tensor.clone()
+1
View File
@@ -23,6 +23,7 @@ from .base import ModelBase, TextModel, gguf, logger
"LlavaForConditionalGeneration",
"VoxtralForConditionalGeneration",
"LlamaForCausalLMEagle3",
"Eagle3LlamaForCausalLM",
"Eagle3Speculator",
"Eagle3DraftModel",
"IQuestCoderForCausalLM",
+3 -4
View File
@@ -114,7 +114,8 @@ class Mamba2Model(TextModel):
hparams["text_config"] = hparams["llm_config"]
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
self.expand = self.find_hparam(["mamba_expand", "expand"], optional=True) or 2
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or self.expand * self.d_model
self.n_group = self.find_hparam(["n_groups"], optional=True) or 1
def set_vocab(self):
@@ -144,11 +145,9 @@ class Mamba2Model(TextModel):
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
# skip the assertion for FalconH1 Model
if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
assert self.d_inner == 2 * self.d_model
assert self.d_inner == self.expand * self.d_model
assert self.d_inner % head_dim == 0
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
+1 -1
View File
@@ -29,7 +29,7 @@ With Termux, you can install and run `llama.cpp` as if the environment were Linu
```
$ apt update && apt upgrade -y
$ apt install git cmake
$ apt install git cmake libandroid-spawn
```
Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake.
+6 -6
View File
@@ -237,8 +237,8 @@ chmod +x ubuntu-llamacpp-ov-install.sh
# ============================================
set -euo pipefail
OPENVINO_VERSION_MAJOR="2026.2"
OPENVINO_VERSION_FULL="2026.2.0.21903.52ddc073857"
OPENVINO_VERSION_MAJOR="2026.2.1"
OPENVINO_VERSION_FULL="2026.2.1.21919.ede283a88e3"
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OPENVINO_INSTALL_DIR="/opt/intel/openvino_${OPENVINO_VERSION_MAJOR}"
@@ -334,7 +334,7 @@ echo " ./build/ReleaseOV/bin/llama-cli -m model.gguf"
```
> [!NOTE]
> The script pins OpenVINO `2026.2` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release.
> The script pins OpenVINO `2026.2.1` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release.
</details>
@@ -364,8 +364,8 @@ REM ============================================
REM llama.cpp OpenVINO Build Script (Ninja)
REM ============================================
set "OPENVINO_VERSION_MAJOR=2026.2"
set "OPENVINO_VERSION_FULL=2026.2.0.21903.52ddc073857"
set "OPENVINO_VERSION_MAJOR=2026.2.1"
set "OPENVINO_VERSION_FULL=2026.2.1.21919.ede283a88e3"
set "SCRIPT_DIR=%~dp0"
set "VCPKG_DIR=C:\vcpkg"
@@ -547,7 +547,7 @@ endlocal
```
> [!NOTE]
> The script pins OpenVINO `2026.2` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release. From any new shell, source the matching `setupvars` script via the junction — `call "C:\Intel\openvino\setupvars.bat"` from `cmd`, or `& "C:\Intel\openvino\setupvars.ps1"` from PowerShell. If `winget` cannot register Visual Studio Build Tools on first run, install them once manually and re-run the script from an elevated **Developer Command Prompt for VS 2022**.
> The script pins OpenVINO `2026.2.1` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release. From any new shell, source the matching `setupvars` script via the junction — `call "C:\Intel\openvino\setupvars.bat"` from `cmd`, or `& "C:\Intel\openvino\setupvars.ps1"` from PowerShell. If `winget` cannot register Visual Studio Build Tools on first run, install them once manually and re-run the script from an elevated **Developer Command Prompt for VS 2022**.
</details>
+18
View File
@@ -413,6 +413,15 @@ In two device selection modes, the default SYCL backend is level_zero, you can c
|------------------|----------------------------------------|
| Single device | --split-mode none --main-gpu DEVICE_ID |
| Multiple devices | --split-mode layer (default) |
| Multiple devices | --split-mode tensor (tensor parallelism) |
`--split-mode tensor` (tensor parallelism) shards each layer across the selected
GPUs. It requires flash attention, which is auto-enabled when `--flash-attn` is
left at its default `auto`, so `--split-mode tensor` works out of the box.
Passing `--flash-attn off` together with `--split-mode tensor` is rejected at
context creation. The default `f16` KV cache is recommended. Tensor parallelism
is currently optimized for 2 GPUs; other device counts fall back to a generic
all-reduce.
Examples:
@@ -715,6 +724,15 @@ In two device selection modes, the default SYCL backend is level_zero, you can c
|------------------|----------------------------------------|
| Single device | --split-mode none --main-gpu DEVICE_ID |
| Multiple devices | --split-mode layer (default) |
| Multiple devices | --split-mode tensor (tensor parallelism) |
`--split-mode tensor` (tensor parallelism) shards each layer across the selected
GPUs. It requires flash attention, which is auto-enabled when `--flash-attn` is
left at its default `auto`, so `--split-mode tensor` works out of the box.
Passing `--flash-attn off` together with `--split-mode tensor` is rejected at
context creation. The default `f16` KV cache is recommended. Tensor parallelism
is currently optimized for 2 GPUs; other device counts fall back to a generic
all-reduce.
Examples:
@@ -24,7 +24,6 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_OPENSSL": "OFF"
}
},
@@ -47,7 +46,6 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "ON",
"GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_OPENSSL": "OFF"
}
},
@@ -73,7 +71,6 @@
"GGML_LLAMAFILE": "OFF",
"GGML_OPENCL": "OFF",
"GGML_HEXAGON": "ON",
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
"LLAMA_OPENSSL": "OFF"
}
},
+41 -1
View File
@@ -13,6 +13,45 @@ The `llama-server` application supports several implementations of speculative d
A much smaller model (called the _draft model_) generates drafts.
A draft model is the most used approach in speculative decoding.
### EAGLE-3 (`draft-eagle3`)
EAGLE-3 uses a small draft model that reads the target model's hidden states to predict the next tokens, so it
reaches higher acceptance than a standalone draft model of the same size. The draft is a one-layer transformer
trained for a specific target model; it shares the target model's tokenizer and, optionally, uses a reduced draft
vocabulary with its own `lm_head`, which is mapped back using a `d2t` table.
Convert the EAGLE-3 checkpoint with `--target-model-dir` so it inherits the target's tokenizer and the layer
indices to read. Both the SpecForge `LlamaForCausalLMEagle3` and the vLLM/AngelSlim `Eagle3LlamaForCausalLM`
checkpoint formats are supported (for example [`AngelSlim/Qwen3-4B_eagle3`](https://huggingface.co/AngelSlim/Qwen3-4B_eagle3)
for `Qwen/Qwen3-4B`):
```bash
python convert_hf_to_gguf.py AngelSlim/Qwen3-4B_eagle3 \
--target-model-dir Qwen/Qwen3-4B --outtype bf16 --outfile Qwen3-4B-eagle3.gguf
llama-server -m Qwen3-4B.gguf -md Qwen3-4B-eagle3.gguf --spec-type draft-eagle3
```
Supported EAGLE-3 draft models include:
- [yuhuili/EAGLE3-LLaMA3.1-Instruct-8B](https://huggingface.co/yuhuili/EAGLE3-LLaMA3.1-Instruct-8B)
- [yuhuili/EAGLE3-LLaMA3.3-Instruct-70B](https://huggingface.co/yuhuili/EAGLE3-LLaMA3.3-Instruct-70B)
- [RedHatAI/gemma-4-31B-it-speculator.eagle3](https://huggingface.co/RedHatAI/gemma-4-31B-it-speculator.eagle3)
- [RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3](https://huggingface.co/RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3)
- [Tengyunw/qwen3_8b_eagle3](https://huggingface.co/Tengyunw/qwen3_8b_eagle3)
- [Tengyunw/qwen3_30b_moe_eagle3](https://huggingface.co/Tengyunw/qwen3_30b_moe_eagle3)
- [AngelSlim/Qwen3-1.7B_eagle3](https://huggingface.co/AngelSlim/Qwen3-1.7B_eagle3)
- [AngelSlim/Qwen3-4B_eagle3](https://huggingface.co/AngelSlim/Qwen3-4B_eagle3)
- [AngelSlim/Qwen3-8B_eagle3](https://huggingface.co/AngelSlim/Qwen3-8B_eagle3)
- [AngelSlim/Qwen3-14B_eagle3](https://huggingface.co/AngelSlim/Qwen3-14B_eagle3)
- [AngelSlim/Qwen3-32B_eagle3](https://huggingface.co/AngelSlim/Qwen3-32B_eagle3)
- [AngelSlim/Qwen3-a3B_eagle3](https://huggingface.co/AngelSlim/Qwen3-a3B_eagle3)
- [RedHatAI/gpt-oss-20b-speculator.eagle3](https://huggingface.co/RedHatAI/gpt-oss-20b-speculator.eagle3)
- [lmsys/EAGLE3-gpt-oss-120b-bf16](https://huggingface.co/lmsys/EAGLE3-gpt-oss-120b-bf16)
- [nvidia/gpt-oss-120b-Eagle3-long-context](https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-long-context)
For the full and up-to-date list of supported models, see #18039.
### n-gram Cache (`ngram-cache`)
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
@@ -108,7 +147,7 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
### General Speculative Parameters
```
--spec-type [none|draft-simple|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
--spec-type [none|draft-simple|draft-eagle3|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
comma-separated list of types of speculative decoding to use
(default: none)
(env: LLAMA_ARG_SPEC_TYPE)
@@ -247,6 +286,7 @@ Specifies a comma-separated list of speculative decoding types to use.
|------|-------------|
| `none` | No speculative decoding (default) |
| `draft-simple` | Use a simple draft model for speculation |
| `draft-eagle3` | Use an EAGLE-3 draft model that reads the target's hidden states |
| `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model |
| `ngram-cache` | Use n-gram cache lookup |
| `ngram-simple` | Use simple n-gram pattern matching |
+1 -2
View File
@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
### GGML Version
set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 15)
set(GGML_VERSION_PATCH 2)
set(GGML_VERSION_PATCH 3)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
@@ -266,7 +266,6 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
"ggml: OpenCL API version to target")
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")
# toolchain for vulkan-shaders-gen
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
+8
View File
@@ -27,6 +27,14 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int de
// split tensor buffer that splits matrices by rows across multiple devices
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
// Tensor parallelism (--split-mode tensor): comm_init/free/allreduce_tensor
// trio queried by the meta-backend via ggml_backend_reg_get_proc_address.
// See typedefs in ggml/include/ggml-backend.h. Mirrors the CUDA backend's
// pattern (ggml_backend_cuda_comm_*).
GGML_BACKEND_API void * ggml_backend_sycl_comm_init(ggml_backend_t * backends, size_t n_backends);
GGML_BACKEND_API void ggml_backend_sycl_comm_free(void * comm_ctx);
GGML_BACKEND_API bool ggml_backend_sycl_comm_allreduce_tensor(void * comm_ctx, struct ggml_tensor ** tensors);
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
+7 -3
View File
@@ -1551,6 +1551,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
int split_backend_id = split->backend_id;
ggml_backend_t split_backend = sched->backends[split_backend_id];
ggml_backend_synchronize(split_backend);
// copy the input tensors to the split backend
for (int input_id = 0; input_id < split->n_inputs; input_id++) {
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
@@ -1561,15 +1563,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
// inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
} else {
} else if (!split_backend->iface.cpy_tensor_async) {
ggml_backend_synchronize(split_backend);
}
ggml_backend_tensor_copy(input, input_cpy);
ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
} else {
// wait for the split backend to finish using the input before overwriting it
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
} else {
} else if (!split_backend->iface.cpy_tensor_async) {
ggml_backend_synchronize(split_backend);
}
@@ -1674,6 +1676,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
}
}
ggml_backend_synchronize(split_backend);
if (!sched->callback_eval) {
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
if (ec != GGML_STATUS_SUCCESS) {
+50 -23
View File
@@ -3688,8 +3688,6 @@ static void ggml_compute_forward_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@@ -3703,25 +3701,49 @@ static void ggml_compute_forward_norm_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
float sum = 0.0;
ggml_vec_sum_f32(ne00, &sum, x);
float mean = sum/ne00;
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
const float * xf = (const float *) x;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
float variance = 0;
float sum = 0.0;
ggml_vec_sum_f32(ne00, &sum, xf);
float mean = sum/ne00;
float * yf = (float *) y;
float variance = 0;
#ifdef GGML_USE_ACCELERATE
mean = -mean;
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
vDSP_measqv(y, 1, &variance, ne00);
mean = -mean;
vDSP_vsadd(xf, 1, &mean, yf, 1, ne00);
vDSP_measqv(yf, 1, &variance, ne00);
#else
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
variance = ggml_vec_cvar_f32(ne00, yf, xf, mean);
#endif //GGML_USE_ACCELERATE
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, y, scale);
const float scale = 1.0f/sqrtf(variance + eps);
ggml_vec_scale_f32(ne00, yf, scale);
} else {
float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += *(const float *) (x + i00*nb00);
}
const float mean = sum/ne00;
float variance = 0.0f;
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float v = *(const float *) (x + i00*nb00) - mean;
*(float *) (y + i00*nb0) = v;
variance += v * v;
}
variance /= ne00;
const float scale = 1.0f/sqrtf(variance + eps);
for (int64_t i00 = 0; i00 < ne00; i00++) {
*(float *) (y + i00*nb0) *= scale;
}
}
}
}
}
@@ -4142,8 +4164,6 @@ static void ggml_compute_forward_l2_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@@ -4158,20 +4178,27 @@ static void ggml_compute_forward_l2_norm_f32(
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
const float xi = *(const float *) (x + i00*nb00);
sum += (ggml_float)(xi * xi);
}
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
ggml_vec_scale_f32(ne00, y, scale);
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
memcpy(y, x, ne00 * sizeof(float));
ggml_vec_scale_f32(ne00, (float *) y, scale);
} else {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const float xi = *(const float *) (x + i00*nb00);
*(float *) (y + i00*nb0) = xi * scale;
}
}
}
}
}
+2 -2
View File
@@ -75,12 +75,12 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
ay1 = GGML_F32_VEC_LOAD(y + i);
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
}
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only
// maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmla on available elements only
if (np2 < n) {
svbool_t pg = svwhilelt_b32(np2, n);
ax1 = svld1_f32(pg, x + np2);
ay1 = svld1_f32(pg, y + np2);
sum1 = svmad_f32_m(pg, ax1, ay1, sum1);
sum1 = svmla_f32_m(pg, sum1, ax1, ay1);
}
// reduce sum1,sum2 to sum1
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
+90 -46
View File
@@ -34,26 +34,26 @@ template <float (*bin_op)(const float, const float),
static __global__ void k_bin_bcast(const src0_t * src0,
const src1_t * src1,
dst_t * dst,
const int ne0,
const int ne1,
const int ne2,
const uint32_t ne0,
const uint32_t ne1,
const uint32_t ne2,
const uint3 ne3,
const uint3 ne10,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*const int s0,*/
const int s1,
const int s2,
const int s3,
const int s00,
const int s01,
const int s02,
const int s03,
const int s10,
const int s11,
const int s12,
const int s13,
/*const uint32_t s0,*/
const uint32_t s1,
const uint32_t s2,
const uint32_t s3,
const uint32_t s00,
const uint32_t s01,
const uint32_t s02,
const uint32_t s03,
const uint32_t s10,
const uint32_t s11,
const uint32_t s12,
const uint32_t s13,
src1_ptrs... src1s) {
ggml_cuda_pdl_lc();
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
@@ -61,7 +61,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
return;
}
@@ -69,25 +69,32 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const uint32_t i12 = fastmodulo(i2, ne12);
const uint32_t i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01;
const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11;
const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1;
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
const uint32_t s0 = blockDim.x * gridDim.x;
ggml_cuda_pdl_sync();
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
for (uint32_t i0 = i0s; i0 < ne0; i0 += s0) {
const uint32_t i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
float result = src0_row ? (float) src0_row[size_t(i0)*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]);
}
dst_row[i0] = (dst_t) result;
// protect i0 from overflow
if (ne0 - i0 <= s0) {
break;
}
}
}
@@ -110,19 +117,19 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const uint3 ne12,
const uint3 ne13,
/*const int s0,*/
const int s1,
const int s2,
const int s3,
const int s00,
const int s01,
const int s02,
const int s03,
const int s10,
const int s11,
const int s12,
const int s13,
const uint32_t s1,
const uint32_t s2,
const uint32_t s3,
const uint32_t s00,
const uint32_t s01,
const uint32_t s02,
const uint32_t s03,
const uint32_t s10,
const uint32_t s11,
const uint32_t s12,
const uint32_t s13,
src1_ptrs... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const uint32_t i = blockDim.x*blockIdx.x + threadIdx.x;
const uint32_t i3 = fastdiv(i, prod_012);
const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
@@ -133,25 +140,25 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
return;
}
const int i11 = fastmodulo(i1, ne11);
const int i12 = fastmodulo(i2, ne12);
const int i13 = fastmodulo(i3, ne13);
const uint32_t i11 = fastmodulo(i1, ne11);
const uint32_t i12 = fastmodulo(i2, ne12);
const uint32_t i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01;
const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11;
const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1;
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;
const int i10 = fastmodulo(i0, ne10);
const uint32_t i10 = fastmodulo(i0, ne10);
ggml_cuda_pdl_sync();
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
float result = src0_row ? (float) src0_row[size_t(i0)*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]);
}
dst_row[i0] = (dst_t) result;
@@ -248,6 +255,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
GGML_ASSERT(ne0 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(ne1 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(ne2 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(ne3 <= std::numeric_limits<uint32_t>::max());
//GGML_ASSERT(s0 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(s1 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(s2 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(s3 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(s00 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(s01 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(s02 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(s03 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(s10 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(s11 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(s12 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(s13 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(cne1[0] <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(cne1[1] <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(cne1[2] <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(cne1[3] <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
@@ -263,6 +295,8 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(ne2 * ne3 <= std::numeric_limits<unsigned int>::max());
const int block_size = 128;
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
@@ -281,7 +315,13 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
if (block_nums.z > 65535 || block_nums.y > 65535) {
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
int64_t block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
GGML_ASSERT(block_num <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(block_num * block_size <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(ne0 * ne1 <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(ne0 * ne1 * ne2 <= std::numeric_limits<uint32_t>::max());
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
@@ -298,6 +338,10 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
}
} else {
GGML_ASSERT(int64_t(block_nums.x) * block_dims.x <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(int64_t(block_nums.y) * block_dims.y <= std::numeric_limits<uint32_t>::max());
GGML_ASSERT(int64_t(block_nums.z) * block_dims.z <= std::numeric_limits<uint32_t>::max());
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
{
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
+80 -29
View File
@@ -53,10 +53,10 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
const int64_t nmat = ne / (ne00 * ne01);
const int64_t n = ne00 * ne01;
const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
const int64_t x = (int64_t) blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
const int64_t y = (int64_t) blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
const int64_t tx = (int64_t) blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
const int64_t ty = (int64_t) blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
__shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
int cur_tile_buf = 0;
@@ -197,7 +197,7 @@ static void ggml_cpy_scalar_contiguous_cuda(
cudaStream_t stream) {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar_contiguous<src_t, dst_t>, launch_params, cx, cdst, ne);
}
@@ -208,6 +208,14 @@ static void ggml_cpy_scalar_cuda(
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const auto launch_scalar_generic = [&]() {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks <= INT_MAX);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar<cpy_1_scalar<src_t, dst_t>>, launch_params,
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
};
if (transposed) {
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
int64_t ne00n, ne01n, ne02n;
@@ -224,20 +232,18 @@ static void ggml_cpy_scalar_cuda(
int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
GGML_ASSERT(grid_x < UINT_MAX);
GGML_ASSERT(grid_y < USHRT_MAX);
GGML_ASSERT(grid_z < USHRT_MAX);
dim3 dimGrid(grid_x, grid_y, grid_z);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar_transpose<dst_t>, launch_params,
cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
GGML_ASSERT(grid_x <= INT_MAX);
if (grid_y > USHRT_MAX || grid_z > USHRT_MAX) {
launch_scalar_generic();
} else {
dim3 dimGrid(grid_x, grid_y, grid_z);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar_transpose<dst_t>, launch_params,
cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
} else {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(cpy_scalar<cpy_1_scalar<src_t, dst_t>>, launch_params,
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
launch_scalar_generic();
}
}
@@ -248,7 +254,7 @@ static void ggml_cpy_f32_q8_0_cuda(
GGML_ASSERT(ne % QK8_0 == 0);
const int64_t num_blocks = ne / QK8_0;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
@@ -259,7 +265,7 @@ static void ggml_cpy_q8_0_f32_cuda(
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
@@ -271,7 +277,7 @@ static void ggml_cpy_f32_q4_0_cuda(
GGML_ASSERT(ne % QK4_0 == 0);
const int64_t num_blocks = ne / QK4_0;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
@@ -284,7 +290,7 @@ static void ggml_cpy_q4_0_f32_cuda(
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
@@ -297,7 +303,7 @@ static void ggml_cpy_f32_q4_1_cuda(
GGML_ASSERT(ne % QK4_1 == 0);
const int64_t num_blocks = ne / QK4_1;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
@@ -310,7 +316,7 @@ static void ggml_cpy_q4_1_f32_cuda(
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
@@ -323,7 +329,7 @@ static void ggml_cpy_f32_q5_0_cuda(
GGML_ASSERT(ne % QK5_0 == 0);
const int64_t num_blocks = ne / QK5_0;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
@@ -336,7 +342,7 @@ static void ggml_cpy_q5_0_f32_cuda(
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
@@ -349,7 +355,7 @@ static void ggml_cpy_f32_q5_1_cuda(
GGML_ASSERT(ne % QK5_1 == 0);
const int64_t num_blocks = ne / QK5_1;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
@@ -362,7 +368,7 @@ static void ggml_cpy_q5_1_f32_cuda(
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
cudaStream_t stream) {
const int64_t num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
@@ -375,11 +381,51 @@ static void ggml_cpy_f32_iq4_nl_cuda(
GGML_ASSERT(ne % QK4_NL == 0);
const int64_t num_blocks = ne / QK4_NL;
GGML_ASSERT(num_blocks < UINT_MAX);
GGML_ASSERT(num_blocks <= INT_MAX);
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
// check if a same-type copy reduces to a 2D strided copy (height rows of width
// contiguous bytes), so it can use cudaMemcpy2DAsync instead of the scalar kernel
static bool ggml_cuda_cpy_as_memcpy_2d(const ggml_tensor * src0, const ggml_tensor * src1,
size_t & width, size_t & height, size_t & spitch, size_t & dpitch) {
// require matching shape: a reshaped copy maps elements by flat order, which the
// prefix walk below does not handle
if (src0->type != src1->type || !ggml_are_same_shape(src0, src1)) {
return false;
}
// grow the contiguous prefix block shared by both tensors
size_t block_nb = ggml_element_size(src0);
int d = 0;
for (; d < GGML_MAX_DIMS; ++d) {
if (src0->nb[d] != block_nb || src1->nb[d] != block_nb) {
break;
}
block_nb *= src0->ne[d];
}
// d == 0: nothing contiguous; d == GGML_MAX_DIMS: fully contiguous (handled by memcpy)
if (d == 0 || d == GGML_MAX_DIMS) {
return false;
}
// dim d carries the rows; everything above it must be a single element
for (int i = d + 1; i < GGML_MAX_DIMS; ++i) {
if (src0->ne[i] != 1) {
return false;
}
}
width = block_nb;
height = src0->ne[d];
spitch = src0->nb[d];
dpitch = src1->nb[d];
return spitch >= width && dpitch >= width;
}
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
@@ -415,6 +461,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
size_t mc_width = 0, mc_height = 0, mc_spitch = 0, mc_dpitch = 0;
if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
@@ -425,6 +473,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
{
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (ggml_cuda_cpy_as_memcpy_2d(src0, src1, mc_width, mc_height, mc_spitch, mc_dpitch)) {
CUDA_CHECK(cudaMemcpy2DAsync(src1_ddc, mc_dpitch, src0_ddc, mc_spitch,
mc_width, mc_height, cudaMemcpyDeviceToDevice, main_stream));
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
if (can_be_transposed) {
ggml_cpy_scalar_cuda<float, float, true>
+21 -5
View File
@@ -3192,11 +3192,24 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
// Enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA
// Excluding this path for HIP and MUSA as a precaution.
// According to the summary in https://github.com/ggml-org/llama.cpp/pull/20793#issuecomment-4275794315, this change is not beneficial for hip anyways.
// Additionally, there is a lot of anectodal evidence that hip/musa stream behavior might not always 1:1 match CUDA behavior.
// e.g. https://github.com/ROCm/rocm-systems/issues/5109
// It thus makes sense to exclude this path for HIP and MUSA. This PR was not aimed these backends, the majority of testing happened on CUDA.
// This can be revisited in the future if enabling copy_from_host benefits hip/MUSA, and if the PR author can extensively test on these backends.
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA)
const bool copy_from_host = false;
#else
const bool copy_from_host = ggml_backend_buffer_is_host(buf_src) && ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU;
#endif
if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) {
return false;
}
if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) {
if (!(copy_from_host || ggml_backend_buffer_is_cuda(buf_src)) || !ggml_backend_buffer_is_cuda(buf_dst)) {
return false;
}
@@ -3207,14 +3220,17 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context;
ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context;
if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
if ((copy_from_host && cuda_ctx_dst->device != buf_ctx_dst->device) ||
!copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
#endif // NDEBUG
return false;
}
if (backend_src != backend_dst) {
if (copy_from_host) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream()));
} else if (backend_src != backend_dst) {
// copy on src stream
if (cuda_ctx_src->device == cuda_ctx_dst->device) {
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
@@ -5334,7 +5350,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return true;
return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]);
break;
+55 -12
View File
@@ -2,6 +2,28 @@
#include <cstdint>
static __global__ void k_compute_out_prod_ptrs(
const float * src0_d, const float * src1_d, float * dst_d,
const float ** ptrs_a, const float ** ptrs_b, float ** ptrs_c,
const int64_t ne2, const int64_t ne3,
const int64_t dps2, const int64_t dps3,
const size_t s02, const size_t s03,
const size_t s12, const size_t s13,
const size_t s2, const size_t s3) {
const int64_t i2 = blockIdx.x*blockDim.x + threadIdx.x;
const int64_t i3 = blockIdx.y*blockDim.y + threadIdx.y;
if (i2 >= ne2 || i3 >= ne3) {
return;
}
const int64_t idx = i3*ne2 + i2;
ptrs_a[idx] = src0_d + (i3/dps3)*s03 + (i2/dps2)*s02;
ptrs_b[idx] = src1_d + i3 *s13 + i2 *s12;
ptrs_c[idx] = dst_d + i3 *s3 + i2 *s2;
}
void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -67,18 +89,39 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
&beta, dst_d + i3 *s3, ldc, s2,
batch_count));
}
} else if (ne2 > 1 || ne3 > 1) {
// dps2 > 1 (src0 broadcast along dim 2 with non-uniform stride) or multiple GEMMs
// along dim 3: compute per-GEMM pointers on the device and use a single batched GEMM.
GGML_ASSERT(ne3 > 0);
GGML_ASSERT(ne2 <= (int64_t) std::numeric_limits<int>::max() / ne3);
const int batch_count = (int) (ne2 * ne3);
ggml_cuda_pool_alloc<const float *> ptrs_a(ctx.pool(), batch_count);
ggml_cuda_pool_alloc<const float *> ptrs_b(ctx.pool(), batch_count);
ggml_cuda_pool_alloc< float *> ptrs_c(ctx.pool(), batch_count);
const dim3 block_dims(16, 16);
const dim3 grid_dims((ne2 + block_dims.x - 1)/block_dims.x, (ne3 + block_dims.y - 1)/block_dims.y);
k_compute_out_prod_ptrs<<<grid_dims, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d,
ptrs_a.get(), ptrs_b.get(), ptrs_c.get(),
ne2, ne3, dps2, dps3, s02, s03, s12, s13, s2, s3);
CUDA_CHECK(cudaGetLastError());
CUBLAS_CHECK(
cublasSgemmBatched(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, ptrs_a.get(), lda,
ptrs_b.get(), ldb,
&beta, ptrs_c.get(), ldc,
batch_count));
} else {
// Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
// with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
for (int64_t i3 = 0; i3 < ne3; ++i3) {
for (int64_t i2 = 0; i2 < ne2; ++i2) {
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
src1_d + i3 *s13 + i2 *s12, ldb,
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
}
}
// ne2 == 1 && ne3 == 1: single GEMM
CUBLAS_CHECK(
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
ne0, ne1, ne01,
&alpha, src0_d, lda,
src1_d, ldb,
&beta, dst_d, ldc));
}
}
+1
View File
@@ -48,6 +48,7 @@
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasSgemmBatched hipblasSgemmBatched
#define cublasSgemmStridedBatched hipblasSgemmStridedBatched
#define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t
+1
View File
@@ -32,6 +32,7 @@
#define cublasSetMathMode mublasSetMathMode
#define cublasSetStream mublasSetStream
#define cublasSgemm mublasSgemm
#define cublasSgemmBatched mublasSgemmBatched
#define cublasSgemmStridedBatched mublasSgemmStridedBatched
#define cublasStatus_t mublasStatus_t
#define cublasOperation_t mublasOperation_t
-4
View File
@@ -25,7 +25,6 @@ include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
add_library(htp_iface OBJECT
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
@@ -72,15 +71,12 @@ function(build_htp_skel V)
-DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
-DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
-DDSP_VERSION=${V}
-DPREBUILT_LIB_DIR="toolv19_${V}")
list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
endfunction()
build_htp_skel(v68)
build_htp_skel(v69)
build_htp_skel(v73)
build_htp_skel(v75)
build_htp_skel(v79)
File diff suppressed because it is too large Load Diff
+162 -56
View File
@@ -5,10 +5,12 @@
#include "ggml-backend-impl.h"
#include "ggml-common.h"
#include <algorithm>
#include <string>
#include <vector>
#include <stdio.h>
#include "htp-ops.h"
#include "htp/matmul-ops.h"
struct htp_opnode {
ggml_tensor * node = nullptr;
@@ -17,6 +19,13 @@ struct htp_opnode {
htp_op_code opcode = HTP_OP_INVALID;
std::vector<ggml_tensor *> extra_dsts;
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS] = {0};
htp_opnode(ggml_tensor * node = nullptr, std::vector<ggml_tensor *> fused = {}, htp_op_code opcode = HTP_OP_INVALID, std::vector<ggml_tensor *> extra_dsts = {})
: node(node), fused(std::move(fused)), opcode(opcode), extra_dsts(std::move(extra_dsts)) {}
ggml_op op() const {
return node->op;
}
@@ -25,6 +34,26 @@ struct htp_opnode {
return fused.empty() ? node : fused.back();
}
void add_fused(ggml_tensor * t, bool extra_dst = false) {
fused.push_back(t);
if (extra_dst) {
extra_dsts.push_back(t);
}
}
std::vector<const ggml_tensor *> get_outputs() const {
std::vector<const ggml_tensor *> res;
if (extra_dsts.empty()) {
res.push_back(dst());
} else {
res.push_back(node);
for (const auto * x : extra_dsts) {
res.push_back(x);
}
}
return res;
}
const ggml_tensor * src0() const {
return node->src[0];
}
@@ -37,10 +66,6 @@ struct htp_opnode {
return ggml_op_is_empty(node->op);
}
void add_fused(ggml_tensor * t) {
fused.push_back(t);
}
bool stackable() const {
switch (this->op()) {
case GGML_OP_MUL_MAT:
@@ -131,87 +156,117 @@ struct htp_opformat {
char types[16 * GGML_MAX_SRC];
char buffs[64 * GGML_MAX_SRC];
char names[64 * GGML_MAX_SRC];
char kparams[128];
int format_tensor_dims(char * str, const struct ggml_tensor * t) {
int format_tensor_dims(char * str, size_t max_size, const struct ggml_tensor * t) {
if (!t) {
return sprintf(str, "NONE");
return snprintf(str, max_size, "NONE");
}
if (t->ne[2] == 1 && t->ne[3] == 1) {
return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
return snprintf(str, max_size, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
} else {
return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
return snprintf(str, max_size, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
}
}
void format_op_dims(char * str, const htp_opnode & node) {
void format_op_dims(char * str, size_t max_size, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += format_tensor_dims(p, inputs[0]);
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[0]), (size_t)(p_end - p));
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += format_tensor_dims(p, inputs[i]);
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[i]), (size_t)(p_end - p));
}
}
p += sprintf(p, " -> ");
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
}
char self[64];
format_tensor_dims(self, node.dst());
p += sprintf(p, "%s", self);
format_tensor_dims(self, sizeof(self), node.dst());
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
}
}
int format_tensor_strides(char * str, const struct ggml_tensor * t) {
int format_tensor_strides(char * str, size_t max_size, const struct ggml_tensor * t) {
if (!t) {
return sprintf(str, "NONE");
return snprintf(str, max_size, "NONE");
}
const char * c = ggml_is_contiguous(t) ? "" : "!";
if (t->ne[2] == 1 && t->ne[3] == 1) {
return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
return snprintf(str, max_size, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
} else {
return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
return snprintf(str, max_size, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
}
}
void format_op_strides(char * str, const htp_opnode & node) {
void format_op_strides(char * str, size_t max_size, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += format_tensor_strides(p, inputs[0]);
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[0]), (size_t)(p_end - p));
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += format_tensor_strides(p, inputs[i]);
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[i]), (size_t)(p_end - p));
}
}
p += sprintf(p, " -> ");
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
}
char self[64];
format_tensor_strides(self, node.dst());
p += sprintf(p, "%s", self);
format_tensor_strides(self, sizeof(self), node.dst());
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
}
}
void format_op_types(char * str, const htp_opnode & node) {
void format_op_types(char * str, size_t max_size, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE");
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE");
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"), (size_t)(p_end - p));
}
p += sprintf(p, " -> ");
for (size_t i = 1; i < inputs.size(); i++) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"), (size_t)(p_end - p));
}
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
}
p += sprintf(p, "%s", ggml_type_name(node.dst()->type));
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", ggml_type_name(node.dst()->type)), (size_t)(p_end - p));
}
}
const char * tensor_buff_name(const struct ggml_tensor * t) {
@@ -221,51 +276,102 @@ struct htp_opformat {
return "NONE";
}
void format_op_buffs(char * str, const htp_opnode & node) {
void format_op_buffs(char * str, size_t max_size, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += sprintf(p, "%s", tensor_buff_name(inputs[0]));
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += sprintf(p, "%s", tensor_buff_name(inputs[i]));
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[0])), (size_t)(p_end - p));
}
p += sprintf(p, " -> ");
for (size_t i = 1; i < inputs.size(); i++) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[i])), (size_t)(p_end - p));
}
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
}
p += sprintf(p, "%s", tensor_buff_name(node.dst()));
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(node.dst())), (size_t)(p_end - p));
}
}
void format_op_names(char * str, const htp_opnode & node) {
void format_op_names(char * str, size_t max_size, const htp_opnode & node) {
char * p = str;
char * p_end = str + max_size;
auto inputs = node.get_inputs();
if (!inputs.empty()) {
p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE");
for (size_t i = 1; i < inputs.size(); i++) {
p += sprintf(p, " x ");
p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE");
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? inputs[0]->name : "NONE"), (size_t)(p_end - p));
}
p += sprintf(p, " -> ");
for (size_t i = 1; i < inputs.size(); i++) {
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? inputs[i]->name : "NONE"), (size_t)(p_end - p));
}
}
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
}
}
p += sprintf(p, "%s", node.dst()->name);
if (p < p_end) {
p += std::min((size_t)snprintf(p, p_end - p, "%s", node.dst()->name), (size_t)(p_end - p));
}
}
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
const char * path = "unknown";
int32_t type = kparams->kernel_type;
if (type == HTP_MM_KERNEL_HMX_2D || type == HTP_MM_KERNEL_HMX_F16_BATCHED) {
path = "hmx-tiled";
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || type == HTP_MM_KERNEL_HVX_F32_F32_VTCM ||
type == HTP_MM_KERNEL_HVX_QUANT_ROW || type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) {
path = "hvx-tiled";
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_DDR || type == HTP_MM_KERNEL_HVX_F16_F32_DDR ||
type == HTP_MM_KERNEL_HVX_F32_F32_DDR || type == HTP_MM_KERNEL_HVX_F32_F16_DDR ||
type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) {
path = "hvx-flat";
}
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
} else {
snprintf(str, max_size, "----");
}
}
void format(const htp_opnode & node) {
format_op_dims(dims, node);
format_op_strides(strides, node);
format_op_types(types, node);
format_op_buffs(buffs, node);
format_op_names(names, node);
format_op_dims(dims, sizeof(dims), node);
format_op_strides(strides, sizeof(strides), node);
format_op_types(types, sizeof(types), node);
format_op_buffs(buffs, sizeof(buffs), node);
format_op_names(names, sizeof(names), node);
format_kernel_params(kparams, sizeof(kparams), node);
}
htp_opformat() {}
htp_opformat() {
strides[0] = '\0';
dims[0] = '\0';
types[0] = '\0';
buffs[0] = '\0';
names[0] = '\0';
kparams[0] = '\0';
}
htp_opformat(const htp_opnode & node) { format(node); }
};
+14 -38
View File
@@ -19,43 +19,9 @@ add_library(${HTP_LIB} SHARED
htp_iface_skel.c
worker-pool.c
hex-dma.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
# HMX acceleration: available on v73+ architectures
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
)
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
target_sources(${HTP_LIB} PRIVATE
hmx-queue.c
flash-attn-ops.c
hmx-flash-attn-ops.c
matmul-ops.c
binary-ops.c
unary-ops.c
@@ -63,7 +29,6 @@ target_sources(${HTP_LIB} PRIVATE
softmax-ops.c
act-ops.c
rope-ops.c
flash-attn-ops.c
set-rows-ops.c
get-rows-ops.c
cpy-ops.c
@@ -79,6 +44,17 @@ target_sources(${HTP_LIB} PRIVATE
pad-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
if (GGML_HEXAGON_FA_EXP2_HF)
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
endif()
build_idl(htp_iface.idl ${HTP_LIB})
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
install(TARGETS ${HTP_LIB})
+13 -15
View File
@@ -3,7 +3,7 @@ if (HEXAGON_TOOLCHAIN_INCLUDED)
endif()
set(HEXAGON_TOOLCHAIN_INCLUDED true)
#Cross Compiling for Hexagon
# Cross Compiling for Hexagon
set(HEXAGON TRUE)
set(CMAKE_SYSTEM_NAME QURT)
set(CMAKE_SYSTEM_PROCESSOR Hexagon)
@@ -14,7 +14,6 @@ set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
set(CUSTOM_RUNELF_PATH "")
#To fix backward compatibility with EAI addon.
if (NOT HEXAGON_SDK_ROOT)
set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
endif()
@@ -31,7 +30,6 @@ endif()
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
#Get the Binary extension of the Hexagon Toolchain
if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
set(HEXAGON_TOOLCHAIN_SUFFIX .exe)
endif()
@@ -48,12 +46,12 @@ set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
HEXAGON_TOOLS_ROOT
)
#QURT Related includes and linker flags
# QURT Related includes and linker flags
set(V_ARCH ${HEXAGON_ARCH})
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}")
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}")
if( ${TREE} MATCHES PAKMAN )
if (${TREE} MATCHES PAKMAN)
set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}")
endif()
message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}")
@@ -83,11 +81,9 @@ set(QURT_START_LINK_LIBS
)
STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}")
set(QURT_END_LINK_LIBS
${TARGET_DIR}/fini.o
)
set(QURT_END_LINK_LIBS ${TARGET_DIR}/fini.o)
#Non QURT related includes and linker flags
# Non QURT related includes and linker flags
set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}")
@@ -99,8 +95,10 @@ if (NOT NO_WRAP_MEM_API)
set(WRAP_MEMALIGN -Wl,--wrap=memalign)
endif()
set(ARCH_FLAGS "-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -mhmx")
set(PIC_SHARED_LD_FLAGS
-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}
${ARCH_FLAGS}
-G0
-fpic
-Wl,-Bsymbolic
@@ -120,13 +118,13 @@ STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}")
set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}")
#System include paths
# System include paths
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs)
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)
#LLVM toolchain setup
#Compiler paths, options and architecture
# LLVM toolchain setup
# Compiler paths, options and architecture
set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})
set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})
@@ -137,8 +135,8 @@ set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)
set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
#Compiler Options
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
# Compiler Options
set(COMMON_FLAGS "${ARCH_FLAGS} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
+2 -3
View File
@@ -18,7 +18,8 @@
#include "htp-ctx.h"
#include "htp-ops.h"
#include "htp-ops.h"
#include "hmx-ops.h"
int hmx_flash_attn_ext(struct htp_ops_context * octx);
// Must be multiple of 32
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
@@ -633,7 +634,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
return HTP_STATUS_NO_SUPPORT;
}
#ifdef HTP_HAS_HMX
// HMX path: head_dim multiple of 64, F16 KV, and no sinks
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) {
int ret = hmx_flash_attn_ext(octx);
@@ -642,7 +642,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
}
// VTCM too small or other failure -> fall through to HVX path
}
#endif
struct htp_fa_context factx;
factx.octx = octx;
+80
View File
@@ -0,0 +1,80 @@
#ifndef HEX_COMMON_H
#define HEX_COMMON_H
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#ifndef SIZE_MAX
#define SIZE_MAX ((size_t)-1)
#endif
#ifndef MAX
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
static inline uint32_t hex_ceil_pow2(uint32_t x) {
if (x <= 1) { return 1; }
int p = 2;
x--;
while (x >>= 1) { p <<= 1; }
return p;
}
static inline size_t hmx_ceil_div(size_t num, size_t den) {
return (num + den - 1) / den;
}
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0;
}
static inline size_t hex_align_up(size_t v, size_t align) {
return hmx_ceil_div(v, align) * align;
}
static inline size_t hex_align_down(size_t v, size_t align) {
return (v / align) * align;
}
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n;
return right_off <= chunk_size;
}
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m);
}
static inline size_t hex_smin(size_t a, size_t b) {
return a < b ? a : b;
}
static inline size_t hex_smax(size_t a, size_t b) {
return a > b ? a : b;
}
static inline void hex_swap_ptr(void ** p1, void ** p2) {
void * t = *p1;
*p1 = *p2;
*p2 = t;
}
static inline bool hex_mul_overflow(size_t a, size_t b, size_t *out) {
if (a != 0 && b > SIZE_MAX / a) return true;
*out = a * b;
return false;
}
static inline bool hex_add_overflow(size_t a, size_t b, size_t *out) {
if (a > SIZE_MAX - b) return true;
*out = a + b;
return false;
}
#endif // HEX_COMMON_H
+1 -5
View File
@@ -5,6 +5,7 @@
#include <hexagon_types.h>
#include <stdbool.h>
#include <stdint.h>
#include "hex-utils.h"
#include "hex-profile.h"
@@ -127,13 +128,8 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src)
return p;
}
#if __HVX_ARCH__ < 73
static const uint32_t dma_src_l2_bypass_on = 1;
static const uint32_t dma_dst_l2_bypass_on = 0;
#else
static const uint32_t dma_src_l2_bypass_on = 1;
static const uint32_t dma_dst_l2_bypass_on = 1;
#endif
static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) {
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
+1 -56
View File
@@ -11,14 +11,7 @@
#include "hex-fastdiv.h"
#include "hex-dump.h"
#ifndef MAX
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#endif
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
#include "hex-common.h"
static inline uint64_t hex_get_cycles() {
uint64_t cycles = 0;
@@ -32,54 +25,6 @@ static inline uint64_t hex_get_pktcnt() {
return pktcnt;
}
static inline uint32_t hex_ceil_pow2(uint32_t x) {
if (x <= 1) { return 1; }
int p = 2;
x--;
while (x >>= 1) { p <<= 1; }
return p;
}
static inline size_t hmx_ceil_div(size_t num, size_t den) {
return (num + den - 1) / den;
}
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
return ((size_t) addr & (align - 1)) == 0;
}
static inline size_t hex_align_up(size_t v, size_t align) {
return hmx_ceil_div(v, align) * align;
}
static inline size_t hex_align_down(size_t v, size_t align) {
return (v / align) * align;
}
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
uint32_t left_off = (size_t) addr & (chunk_size - 1);
uint32_t right_off = left_off + n;
return right_off <= chunk_size;
}
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
return m * ((n + m - 1) / m);
}
static inline size_t hex_smin(size_t a, size_t b) {
return a < b ? a : b;
}
static inline size_t hex_smax(size_t a, size_t b) {
return a > b ? a : b;
}
static inline void hex_swap_ptr(void ** p1, void ** p2) {
void * t = *p1;
*p1 = *p2;
*p2 = t;
}
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
Q6_l2fetch_AP((void *) p, control);
+13 -13
View File
@@ -49,7 +49,7 @@
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) {
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
@@ -70,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV,
+ k_dma_size * 2 // K DMA x2
+ v_dma_size * 2 // V DMA x2
+ k_tile_size * 1 // K tiles
+ v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
+ s_tile_size * 2 // S + P
+ d_tile_size * 1 // D (diagonal matrix)
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
@@ -290,7 +290,7 @@ static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) =
struct hmx_fa_context {
const struct htp_ops_context * octx;
bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
bool pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
uint32_t n_threads;
// Op parameters
@@ -409,7 +409,7 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
return;
}
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
__fp16 * v_tiles_dest = factx->pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
@@ -1312,13 +1312,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS);
const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc;
const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
const bool pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
// Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1
const uint32_t n_threads = use_pipeline ? n_threads_init : 1;
const uint32_t n_threads = pipeline ? n_threads_init : 1;
FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu",
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget);
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, pipeline, vtcm_budget);
// ======== Build context ========
struct hmx_fa_context factx;
@@ -1339,7 +1339,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
factx.n_kv_blocks = n_kv_blocks;
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32);
factx.use_pipeline = use_pipeline;
factx.pipeline = pipeline;
factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1);
// Extract op parameters (mutable during softcap adjustment, then stored as const in factx)
@@ -1405,7 +1405,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes);
factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes);
factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
if (use_pipeline) {
if (pipeline) {
factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
} else {
factx.vtcm_v_tiles[1] = NULL;
@@ -1456,7 +1456,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
// ======== HMX lock strategy ========
// Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend.
// Fallback: main thread holds the lock (original behavior).
if (!factx.use_pipeline) {
if (!factx.pipeline) {
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
}
@@ -1550,7 +1550,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
const size_t k_src_stride = size_k_row_padded / sizeof(__fp16);
const size_t v_src_stride = size_v_row_padded / sizeof(__fp16);
if (factx.use_pipeline) {
if (factx.pipeline) {
// ==================================================================
// Pipeline path: HVX phases ‖ HMX queue worker
// ==================================================================
@@ -1780,7 +1780,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
// HMX: O_final = diag(1/l) @ O_prev
if (factx.use_pipeline) {
if (factx.pipeline) {
on_job.o_curr = o_tile_curr;
on_job.o_prev = o_tile_prev;
on_job.d_tiles = factx.vtcm_d_tiles;
@@ -1826,7 +1826,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
} // end KV head loop
} // end batch loop
if (factx.use_pipeline) {
if (factx.pipeline) {
hmx_queue_suspend(ctx->hmx_queue);
} else {
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
-6
View File
@@ -1,6 +0,0 @@
// HMX operations compiled as a single translation unit.
// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO.
#include "hmx-queue.c"
#include "hmx-matmul-ops.c"
#include "hmx-flash-attn-ops.c"
-88
View File
@@ -1,88 +0,0 @@
// HMX operation entry-point declarations.
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
#ifndef HMX_OPS_H
#define HMX_OPS_H
#include <stddef.h>
#include <stdint.h>
#include "htp-ops.h"
#ifdef __cplusplus
extern "C" {
#endif
typedef struct {
float *dst;
const float *activation;
const __fp16 *permuted_weight;
int m;
int k;
int n;
int act_stride;
int weight_stride;
int dst_stride;
int ne02;
int ne03;
int ne12;
int ne13;
size_t src0_nb2;
size_t src0_nb3;
size_t src1_nb2;
size_t src1_nb3;
size_t dst_nb2;
size_t dst_nb3;
} hmx_matmul_f16_f32_batched_params_t;
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
// act_stride: activation row stride in elements (= k for contiguous, or
// nb[1]/sizeof(float) for permuted tensors like attention Q).
// weight_stride: weight row stride in elements (= k for compact weights, or
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
int hmx_matmul_f16_f32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const __fp16 *permuted_weight,
int m, int k, int n,
int act_stride,
int weight_stride);
// Batched F16 wrapper over hmx_mat_mul_f16_f32.
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params);
// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4)
int hmx_matmul_2d_f32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const uint8_t *permuted_weight,
int m, int k, int n,
int act_stride,
int weight_stride,
int weight_type);
struct mmid_row_mapping;
int hmx_matmul_id_2d_f32(struct htp_context *ctx,
float *restrict dst,
const float *activation,
const uint8_t *permuted_weight,
int m, int k, int n,
int ne11,
size_t act_nb1, size_t act_nb2,
size_t dst_nb1, size_t dst_nb2,
int weight_stride,
int weight_type,
const struct mmid_row_mapping *matrix_rows,
int cur_a,
int mapping_stride);
// HMX flash attention
int hmx_flash_attn_ext(struct htp_ops_context * octx);
#ifdef __cplusplus
}
#endif
#endif // HMX_OPS_H
+9 -3
View File
@@ -13,7 +13,9 @@
#include <stdint.h>
#include <stdbool.h>
#ifndef HTP_MAX_NTHREADS
#define HTP_MAX_NTHREADS 10
#endif
#define HTP_MAX_MMAPS 16
// Memory mapping
@@ -42,9 +44,13 @@ struct htp_ops_context {
enum htp_op_code op; // FIXME: rename to opcode
int32_t op_params[HTP_OP_MAX_PARAMS];
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS];
const struct htp_tensor * src[HTP_OP_MAX_INPUTS];
const struct htp_tensor * dst;
union {
const struct htp_tensor * dst;
const struct htp_tensor * dsts[HTP_OP_MAX_OUTPUTS];
};
// TODO convert these to an array
struct htp_spad src0_spad;
@@ -87,13 +93,13 @@ struct htp_context {
struct htp_ops_context octx;
#ifdef HTP_HAS_HMX
struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap
#endif
};
int op_matmul(struct htp_ops_context * octx);
int op_matmul_id(struct htp_ops_context * octx);
int op_matmul_qkv(struct htp_ops_context * octx);
int op_matmul_ffn(struct htp_ops_context * octx);
int op_binary(struct htp_ops_context * octx);
int op_unary(struct htp_ops_context * octx);
int op_sum_rows(struct htp_ops_context * octx);
+15 -8
View File
@@ -28,18 +28,19 @@ enum htp_data_type {
HTP_TYPE_MXFP4 = 39,
// types used internally for repack, dyn.quant, etc
HTP_TYPE_Q4_0x4x2 = 200,
HTP_TYPE_Q4_1x4x2,
HTP_TYPE_Q8_0x4x2,
HTP_TYPE_MXFP4x4x2,
HTP_TYPE_Q4_0_TILED = 200,
HTP_TYPE_Q4_1_TILED,
HTP_TYPE_Q8_0_TILED,
HTP_TYPE_MXFP4_TILED,
HTP_TYPE_INVALID
};
// Constats for internal types
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
#define QK_Q4_0_TILED 256 // 32x32 Q4_0 tiled layout
#define QK_Q8_0_TILED 128 // 32x32 Q8_0 tiled layout
#define QK_MXFP4_TILED 256 // 32x32 MXFP4 tiled layout
// Mask to enable various stages of the Ops.
@@ -57,6 +58,8 @@ enum htp_op_code {
HTP_OP_DIV = 3,
HTP_OP_MUL_MAT,
HTP_OP_MUL_MAT_ID,
HTP_OP_MUL_MAT_QKV,
HTP_OP_MUL_MAT_FFN,
HTP_OP_RMS_NORM,
HTP_OP_RMS_NORM_MUL,
HTP_OP_UNARY_SILU,
@@ -99,7 +102,9 @@ enum htp_op_code {
#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS
#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS
#define HTP_OP_MAX_OUTPUTS 4
#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS
#define HTP_OP_MAX_KERN_PARAMS 32
#define HTP_OP_MAX_BUFS 16
#define HTP_OP_MAX_REQS 256
@@ -142,8 +147,10 @@ struct htp_op_desc {
uint32_t opcode; // GGML/HTP Op
uint32_t flags; // Op flags
int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; // generic blob for host-precomputed parameters
uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices
uint16_t dst; // Output tensor index
uint16_t dst[HTP_OP_MAX_OUTPUTS]; // Output tensor indices
uint16_t pad[2]; // padding to align to 64 bits
};
#ifndef HTP_MAX_NTHREADS
+2 -1
View File
@@ -11,12 +11,13 @@ struct htp_iface_pmu_conf {
};
interface htp_iface : remote_handle64 {
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem);
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 n_hmx, in uint64 max_vmem);
AEEResult stop();
AEEResult mmap(in uint32 fd, in uint32 size);
AEEResult munmap(in uint32 fd);
AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu);
AEEResult etm(in uint32 enable);
AEEResult hwinfo(rout uint32 n_threads, rout uint32 n_hvx, rout uint32 n_hmx, rout uint64 vtcm_size);
};
#endif /* HTP_IDL */
+13 -18
View File
@@ -170,25 +170,7 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
}
#endif
/* Q6_Vsf_equals_Vw is only available on v73+.*/
#if __HVX_ARCH__ < 73
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
{
HVX_Vector const vzero = Q6_V_vzero();
HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
return ret;
}
static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
{
return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));
}
#endif
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
// This looks complicated.
@@ -305,4 +287,17 @@ static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) {
#endif // __HVX_ARCH__ < 79
static inline HVX_Vector hvx_vec_load_act_tile(const uint8_t * y_q, uint32_t kt, HVX_Vector * v_act_all) {
if (kt % 4 == 0) {
*v_act_all = hvx_vmem(y_q + kt * 32);
return *v_act_all;
} else if (kt % 4 == 1) {
return Q6_V_vror_VR(*v_act_all, 32);
} else if (kt % 4 == 2) {
return Q6_V_vror_VR(*v_act_all, 64);
} else {
return Q6_V_vror_VR(*v_act_all, 96);
}
}
#endif /* HVX_BASE_H */
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+81 -23
View File
@@ -361,7 +361,7 @@ static void vtcm_free(struct htp_context * ctx) {
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
static void htp_error_callback(dspqueue_t queue, int error, void * context);
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) {
AEEResult htp_iface_start(remote_handle64 handle, uint32_t sess_id, uint64_t dsp_queue_id, uint32_t n_hvx, uint32_t n_hmx, uint64_t max_vmem) {
struct htp_context * ctx = (struct htp_context *) handle;
if (!ctx) {
@@ -395,10 +395,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
return AEE_ENOMEMORY;
}
#ifdef HTP_HAS_HMX
ctx->hmx_enabled = use_hmx;
ctx->hmx_enabled = n_hmx;
ctx->hmx_queue = NULL;
if (use_hmx) {
if (n_hmx) {
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
if (ctx->hmx_queue) {
ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS];
@@ -407,8 +406,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
ctx->hmx_enabled = false;
}
}
FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx);
#endif
FARF(HIGH, "HMX %s (n_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", n_hmx);
qurt_sysenv_max_hthreads_t hw_threads;
qurt_sysenv_get_max_hw_threads(&hw_threads);
@@ -481,13 +479,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
dma_queue_delete(ctx->dma[i]);
}
#ifdef HTP_HAS_HMX
if (ctx->hmx_queue) {
hmx_queue_delete(ctx->hmx_queue);
ctx->hmx_queue = NULL;
}
ctx->hmx_enabled = false;
#endif
vtcm_free(ctx);
@@ -500,6 +496,36 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
return AEE_SUCCESS;
}
AEEResult htp_iface_hwinfo(remote_handle64 handle, uint32_t * n_threads, uint32_t * n_hvx, uint32_t * n_hmx, uint64_t * vtcm_size) {
(void)handle;
if (!n_threads || !n_hvx || !n_hmx || !vtcm_size) {
return AEE_EBADPARM;
}
qurt_sysenv_max_hthreads_t hw_threads;
qurt_sysenv_get_max_hw_threads(&hw_threads);
uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF;
uint32_t n_hvx_val = hw_nhvx;
if (n_hvx_val > hw_threads.max_hthreads) {
n_hvx_val = hw_threads.max_hthreads;
}
if (n_hvx_val > HTP_MAX_NTHREADS) {
n_hvx_val = HTP_MAX_NTHREADS;
}
// for now we force n_threads == n_hvx
*n_threads = n_hvx_val;
*n_hvx = n_hvx_val;
*n_hmx = 1;
uint32_t vtcm_sz = 8 * 1024 * 1024; // 8MB default fallback
HAP_compute_res_query_VTCM(0, (unsigned int *)&vtcm_sz, NULL, NULL, NULL);
*vtcm_size = vtcm_sz;
return AEE_SUCCESS;
}
static void htp_error_callback(dspqueue_t queue, int error, void * context) {
// No errors expected on the DSP.
FARF(ERROR, "Error callback: 0x%08x", (unsigned) error);
@@ -554,6 +580,12 @@ static int execute_op(struct htp_ops_context * octx) {
case HTP_OP_MUL_MAT_ID:
return op_matmul_id(octx);
case HTP_OP_MUL_MAT_QKV:
return op_matmul_qkv(octx);
case HTP_OP_MUL_MAT_FFN:
return op_matmul_ffn(octx);
case HTP_OP_MUL:
case HTP_OP_ADD:
case HTP_OP_SUB:
@@ -762,8 +794,9 @@ static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, str
}
}
static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
static int proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
memcpy(octx->op_params, op->params, sizeof(octx->op_params));
memcpy(octx->kernel_params, op->kernel_params, sizeof(octx->kernel_params));
octx->flags = op->flags;
octx->op = op->opcode;
@@ -785,22 +818,41 @@ static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens,
src->ne[0], src->ne[1], src->ne[3], src->ne[3]);
}
// Prep output tensor
struct htp_tensor *dst = tens + op->dst;
// Prep output tensors
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
uint16_t dst_idx = op->dst[i];
if (dst_idx == 0xffff) {
octx->dsts[i] = NULL;
continue;
}
struct htp_tensor *dst = tens + dst_idx;
octx->dsts[i] = dst;
octx->dst = dst;
FARF(HIGH, "prep-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, dst_idx, (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
}
FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
int status = execute_op(octx);
(void) execute_op(octx);
octx->src0_spad.src = NULL;
octx->src1_spad.src = NULL;
octx->src2_spad.src = NULL;
octx->src3_spad.src = NULL;
octx->dst_spad.src = NULL;
// flush buffers on output
hex_l2flush((void *) dst->data, dst->size);
dst->flags |= HTP_TENSOR_FLUSHED;
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
if (octx->dsts[i]) {
struct htp_tensor *dst = (struct htp_tensor *)octx->dsts[i];
hex_l2flush((void *) dst->data, dst->size);
dst->flags |= HTP_TENSOR_FLUSHED;
FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
FARF(HIGH, "post-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, op->dst[i], (void*) dst->data, dst->size,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
}
}
return status;
}
#define DSPQUEUE_POLL_TIMEOUT_USEC 100
@@ -892,20 +944,26 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
}
}
int op_status = HTP_STATUS_OK;
uint32_t op_wakeup = n_ops / 2; // half-way throgh the batch
for (uint32_t i=0; i < n_ops; i++) {
struct profile_data prof;
if (i == (n_ops-1)) {
// wake up the host before starting the last op
if (i == op_wakeup) {
dspqueue_write_early_wakeup_noblock(queue, 0, 0);
}
profile_start(ctx->profiler, &prof);
proc_op_req(octx, tens, i, &ops[i]);
op_status = proc_op_req(octx, tens, i, &ops[i]);
profile_stop(ctx->profiler, &prof);
if (op_status != HTP_STATUS_OK) {
break;
}
if (ctx->profiler) {
pds[i].opcode = ops[i].opcode;
pds[i].usecs = prof.usecs;
@@ -919,7 +977,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
struct htp_opbatch_rsp rsp;
rsp.id = req.id;
rsp.status = HTP_STATUS_OK;
rsp.status = op_status;
rsp.n_bufs = n_bufs;
rsp.n_tensors = n_tens;
rsp.n_ops = n_ops;
File diff suppressed because it is too large Load Diff
+508
View File
@@ -0,0 +1,508 @@
#ifndef HTP_MATMUL_OPS_H
#define HTP_MATMUL_OPS_H
#include <stdint.h>
#include <stddef.h>
#include "htp-ops.h"
#include "hex-fastdiv.h"
#include "hex-common.h"
#ifdef __cplusplus
extern "C" {
#endif
// --- HMX Tile Constraints ---
#define HTP_MM_HMX_TILE_N_COLS 32
#define HTP_MM_HMX_TILE_N_ROWS 32
#define HTP_MM_HMX_TILE_SIZE (32 * 32 * sizeof(__fp16)) // 2048 bytes
#define HTP_MM_HMX_TILE_N_ELMS 1024
#define HTP_MM_HMX_MIN_NROWS 4
// --- Weight Repacked Tile Sizes ---
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_0 576
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_1 640
#define HTP_MM_WEIGHT_TILE_SIZE_Q8_0 1088
#define HTP_MM_WEIGHT_TILE_SIZE_IQ4_NL 576
#define HTP_MM_WEIGHT_TILE_SIZE_MXFP4 544
// --- Weight Repacked Aligned Tile Sizes ---
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0 640
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1 640
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0 1152
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_IQ4_NL 640
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4 640
// --- Activation Tiled Block Sizes (including padding) ---
#define HTP_MM_ACT_TILE_SIZE_Q8_0 1152
#define HTP_MM_ACT_TILE_SIZE_Q8_1 1280
#define HTP_MM_MAX_PREFETCH 16
// --- Solver Cost Model Penalty Weights (HMX-specific) ---
#define HTP_MM_HMX_COST_W_DEQUANT 3 // cost penalty for quantized weight loading/dequantization
#define HTP_MM_HMX_COST_A_CONVERT 2 // cost penalty for activation loading/conversion
// --- DMA Activation Transfer Configuration ---
#define HTP_MM_DMA_ACT_ROWS_PER_STEP 2
#define HTP_MM_DMA_ACT_MULTIPLIER 4
enum htp_mm_kernel_type {
HTP_MM_KERNEL_UNSUPPORTED = 0,
// HMX paths
HTP_MM_KERNEL_HMX_2D,
HTP_MM_KERNEL_HMX_F16_BATCHED,
// HVX floating-point paths
HTP_MM_KERNEL_HVX_F16_F16_VTCM,
HTP_MM_KERNEL_HVX_F16_F16_DDR,
HTP_MM_KERNEL_HVX_F16_F32_DDR,
HTP_MM_KERNEL_HVX_F32_F32_VTCM,
HTP_MM_KERNEL_HVX_F32_F32_DDR,
HTP_MM_KERNEL_HVX_F32_F16_DDR,
// HVX quantized paths
HTP_MM_KERNEL_HVX_QUANT_ROW, // standard row-wise parallel quantization
HTP_MM_KERNEL_HVX_QUANT_BLOCK, // parallel block-wise quantization
HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT, // row-wise fallback flat quantization
};
// Op-specific struct for precomputed matmul params
struct htp_mm_kernel_params {
int32_t kernel_type; // enum htp_mm_kernel_type
int32_t pipeline; // 1 = pipelined execution, 0 = standard
int32_t m_chunk; // Row chunk size (M chunk)
int32_t n_chunk; // Col chunk size (N chunk)
int32_t n_threads; // Number of threads to spawn
int32_t n_act_threads; // Number of threads for activation preparation
int32_t n_hmx; // 1 = use HMX, 0 = use HVX
int32_t n_prefetch; // Prefetch lookahead buffers/rows in VTCM
int32_t tile_size; // Weight tile size
int32_t aligned_tile_size; // Aligned weight tile size (padded to 128)
int32_t src1_row_size; // Row size for quantized activation
int32_t vtcm_size; // Total required scratchpad size in VTCM
int32_t vtcm_src0_size; // src0 scratchpad size in VTCM
int32_t vtcm_src1_size; // src1 scratchpad size in VTCM
int32_t vtcm_src2_size; // src2 scratchpad size in VTCM (fused only)
int32_t vtcm_src3_size; // src3 scratchpad size in VTCM (fused only)
int32_t vtcm_dst_size; // dst scratchpad size in VTCM
// Precomputed division values
struct fastdiv_values div_ne12_ne1;
struct fastdiv_values div_ne1;
struct fastdiv_values div_r2;
struct fastdiv_values div_r3;
struct fastdiv_values div_ne11;
};
#if defined(__cplusplus)
static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
#else
_Static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
#endif
struct mmid_row_mapping {
uint32_t i1;
uint32_t i2;
};
// Search for optimal (mc, nc) chunk sizes within VTCM budget.
static inline int htp_mm_hmx_compute_chunks(size_t vtcm_total,
size_t overhead,
size_t per_n_cost,
size_t per_m_cost,
size_t per_mn_cost,
size_t m,
size_t n,
size_t m_block_cost,
size_t n_block_cost,
size_t * m_chunk_out,
size_t * n_chunk_out,
size_t * total_out) {
if (m == 0 || n == 0) return -1;
if (vtcm_total <= overhead) return -1;
if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1;
const size_t usable = vtcm_total - overhead;
size_t best_cost = SIZE_MAX;
size_t best_mn = 0;
size_t best_m = 0, best_n = 0;
const size_t n_max = hex_align_down((size_t)n, HTP_MM_HMX_TILE_N_COLS);
for (size_t nc = n_max; nc >= HTP_MM_HMX_TILE_N_COLS; nc -= HTP_MM_HMX_TILE_N_COLS) {
size_t n_fixed = 0, ncmn = 0, mc_denom = 0;
if (hex_mul_overflow(nc, per_n_cost, &n_fixed)) continue;
if (n_fixed >= usable) goto next_nc;
if (hex_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc;
if (hex_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc;
{
size_t remain = usable - n_fixed;
size_t mc = remain / mc_denom;
mc = hex_align_down(mc, HTP_MM_HMX_TILE_N_ROWS);
mc = hex_smin(mc, m);
if (mc == 0) {
goto next_nc;
}
size_t mblocks = ((size_t) m + mc - 1) / mc;
size_t nblocks = ((size_t) n + nc - 1) / nc;
size_t cost = mblocks * m_block_cost + nblocks * n_block_cost;
size_t mn = mc * nc;
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
best_cost = cost;
best_mn = mn;
best_m = mc;
best_n = nc;
}
}
next_nc:
if (nc == HTP_MM_HMX_TILE_N_COLS) break; // avoid size_t underflow
}
if (best_m == 0 || best_n == 0) return -1;
// Compute exact total (with overflow checks)
size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0;
if (hex_mul_overflow(best_n, per_n_cost, &t0)) return -1;
if (hex_mul_overflow(best_m, per_m_cost, &t1)) return -1;
if (hex_mul_overflow(best_m, best_n, &mn)) return -1;
if (hex_mul_overflow(mn, per_mn_cost, &t2)) return -1;
if (hex_add_overflow(t0, t1, &total)) return -1;
if (hex_add_overflow(total, t2, &total)) return -1;
if (hex_add_overflow(total, overhead, &total)) return -1;
*m_chunk_out = best_m;
*n_chunk_out = best_n;
*total_out = total;
return 0;
}
// --- Tile Size Helpers ---
static inline uint32_t htp_mm_get_weight_tile_size(int weight_type) {
switch (weight_type) {
case HTP_TYPE_Q4_0:
case HTP_TYPE_IQ4_NL:
return HTP_MM_WEIGHT_TILE_SIZE_Q4_0;
case HTP_TYPE_Q4_1:
return HTP_MM_WEIGHT_TILE_SIZE_Q4_1;
case HTP_TYPE_Q8_0:
return HTP_MM_WEIGHT_TILE_SIZE_Q8_0;
case HTP_TYPE_MXFP4:
return HTP_MM_WEIGHT_TILE_SIZE_MXFP4;
default:
return 0;
}
}
static inline uint32_t htp_mm_get_weight_aligned_tile_size(int weight_type) {
switch (weight_type) {
case HTP_TYPE_Q4_0:
case HTP_TYPE_IQ4_NL:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0;
case HTP_TYPE_Q4_1:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1;
case HTP_TYPE_Q8_0:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0;
case HTP_TYPE_MXFP4:
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4;
default:
return 0;
}
}
// --- Activation/Row Size Helpers ---
static inline size_t htp_mm_q8_0_tiled_row_size(uint32_t ne) {
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
const uint32_t nb_32 = ne_padded / 32;
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_0;
}
static inline size_t htp_mm_q8_1_tiled_row_size(uint32_t ne) {
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
const uint32_t nb_32 = ne_padded / 32;
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_1;
}
static inline size_t htp_mm_q8_0_flat_row_size(uint32_t ne) {
const uint32_t quants_size = hex_align_up(ne, 128);
const uint32_t num_scales = (ne + 31) / 32;
const uint32_t scales_size = hex_align_up(num_scales * 2, 128);
return quants_size + scales_size;
}
static inline size_t htp_mm_q8_1_flat_row_size(uint32_t ne) {
const uint32_t quants_size = hex_align_up(ne, 128);
const uint32_t num_scales = (ne + 31) / 32;
const uint32_t scales_size = hex_align_up(num_scales * 4, 128);
return quants_size + scales_size;
}
static inline size_t htp_mm_get_tiled_row_stride(int weight_type, uint32_t k) {
uint32_t nb = (k + QK_Q4_0_TILED - 1) / QK_Q4_0_TILED;
switch (weight_type) {
case HTP_TYPE_Q4_0:
case HTP_TYPE_IQ4_NL:
case HTP_TYPE_Q4_1:
case HTP_TYPE_Q8_0:
case HTP_TYPE_MXFP4:
return (size_t) nb * htp_mm_get_weight_tile_size(weight_type);
case HTP_TYPE_F16:
return (size_t) k * sizeof(__fp16);
case HTP_TYPE_F32:
return (size_t) k * sizeof(float);
default:
return 0;
}
}
static inline size_t htp_mm_round_up(size_t n, size_t m) {
return ((n + m - 1) / m) * m;
}
static inline bool htp_mm_hmx_pipeline(uint32_t m) {
return m > 32;
}
static inline void htp_mm_hmx_get_2d_chunk_costs(
int wtype, uint32_t k, bool pipeline, uint32_t aligned_tile_size,
size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
) {
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
const size_t vec_dot_size = k * sizeof(uint16_t);
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0;
*size_per_n_out = (pipeline ? 2 : 1) * (is_quant ? qweight_row_stride : row_stride) +
(pipeline ? 2 * vec_dot_size : vec_dot_size);
*size_per_m_out = vec_dot_size;
*size_per_mn_out = (pipeline ? 2 : 1) * sizeof(uint16_t);
}
static inline void htp_mm_hmx_get_batched_chunk_costs(
uint32_t k, uint32_t group_size,
size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
) {
const size_t vec_dot_size = k * sizeof(uint16_t);
*size_per_n_out = 3 * vec_dot_size;
*size_per_m_out = group_size * vec_dot_size;
*size_per_mn_out = sizeof(uint16_t);
}
static inline size_t htp_mm_hmx_get_2d_vtcm_size(
int wtype, uint32_t k, size_t mc, size_t nc, bool pipeline, uint32_t act_threads, uint32_t aligned_tile_size
) {
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
const size_t vec_dot_size = k * sizeof(uint16_t);
const size_t act_f32_size = htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE);
size_t weight_area_size = is_quant
? htp_mm_round_up((nc / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE)
: htp_mm_round_up(nc * row_stride, HTP_MM_HMX_TILE_SIZE);
if (pipeline) {
weight_area_size *= 2;
}
const size_t act_area_size = htp_mm_round_up(mc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
const size_t output_area_size = htp_mm_round_up(mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
size_t scratch0_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
size_t scratch1_size = pipeline ? scratch0_size : 0;
size_t scratch2_size = pipeline ? output_area_size : 0;
return weight_area_size + act_area_size + act_f32_size + output_area_size +
scratch0_size + scratch1_size + scratch2_size + 256;
}
static inline size_t htp_mm_hmx_get_batched_vtcm_size(
int wtype, uint32_t k, size_t mc, size_t nc, uint32_t group_size, bool use_dma_activation, bool pipeline, uint32_t act_threads) {
(void)wtype;
(void)pipeline;
const size_t vec_dot_size = k * sizeof(uint16_t);
const size_t f32_scratch_size = use_dma_activation
? htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0;
const size_t act_head_stride = mc * k;
const size_t weight_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
const size_t act_area_size = htp_mm_round_up(group_size * act_head_stride * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
const size_t output_area_size = htp_mm_round_up(group_size * mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
const size_t scratch_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
return weight_area_size + act_area_size + output_area_size +
2 * scratch_area_size + 256 + f32_scratch_size;
}
static inline size_t htp_mm_hvx_get_vtcm_sizes(
int kernel_type,
int wtype,
uint32_t ne10, // k
uint32_t src1_nrows, // m_total (or act_nrows)
uint32_t n_threads,
size_t dst_row_size,
size_t src0_row_size,
size_t src1_row_size,
uint32_t n_prefetch,
size_t * vtcm_src0_size_out,
size_t * vtcm_src1_size_out,
size_t * vtcm_dst_size_out
) {
size_t vtcm_src0_size = 0;
size_t vtcm_src1_size = 0;
size_t vtcm_dst_size = 0;
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
wtype == HTP_TYPE_MXFP4);
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
const size_t dst_nrows = (src1_nrows > 1) ? 0 : 1;
switch (kernel_type) {
case HTP_MM_KERNEL_HVX_F16_F16_VTCM: {
size_t f16_src1_row_size = htp_mm_round_up(ne10 * 2, 128);
vtcm_src1_size = htp_mm_round_up(f16_src1_row_size * src1_nrows, 256);
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
break;
}
case HTP_MM_KERNEL_HVX_F16_F32_DDR:
case HTP_MM_KERNEL_HVX_F16_F16_DDR:
case HTP_MM_KERNEL_HVX_F32_F32_DDR:
case HTP_MM_KERNEL_HVX_F32_F16_DDR: {
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size, 256) * n_threads;
vtcm_src1_size = htp_mm_round_up(n_prefetch * src1_row_size, 256) * n_threads;
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
break;
}
case HTP_MM_KERNEL_HVX_F32_F32_VTCM: {
size_t f32_src1_row_size = htp_mm_round_up(ne10 * 4, 128);
vtcm_src1_size = htp_mm_round_up(f32_src1_row_size * src1_nrows, 256);
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
break;
}
case HTP_MM_KERNEL_HVX_QUANT_BLOCK:
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
// src0 spad is also used in dynamic quantizer to store padded src1 rows
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
break;
}
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
if (vtcm_src0_size < src1_row_size_padded) {
vtcm_src0_size = src1_row_size_padded;
}
vtcm_src0_size = vtcm_src0_size * n_threads;
vtcm_dst_size = vtcm_dst_size * n_threads;
if (is_repack) {
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
uint32_t n_k_tiles = ne10 / 32;
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
vtcm_src0_size = repacked_vtcm_size * n_threads;
}
break;
}
default:
break;
}
*vtcm_src0_size_out = vtcm_src0_size;
*vtcm_src1_size_out = vtcm_src1_size;
*vtcm_dst_size_out = vtcm_dst_size;
return vtcm_src0_size + vtcm_src1_size + vtcm_dst_size;
}
static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
int wtype,
uint32_t ne10, // k
uint32_t src1_nrows,
uint32_t n_threads,
size_t src0_row_size, // nb01
uint32_t n_prefetch,
size_t * vtcm_src0_size_out,
size_t * vtcm_src1_size_out
) {
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
wtype == HTP_TYPE_MXFP4);
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
const size_t src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10)
: htp_mm_q8_0_tiled_row_size(ne10);
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
if (src0_sz_per_thread < src1_row_size_padded) {
src0_sz_per_thread = src1_row_size_padded;
}
if (is_repack) {
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
const uint32_t n_k_tiles = ne10 / 32;
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
if (repacked_vtcm_size < src1_row_size_padded) {
repacked_vtcm_size = src1_row_size_padded;
}
src0_sz_per_thread = repacked_vtcm_size;
}
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
*vtcm_src0_size_out = vtcm_src0_size;
*vtcm_src1_size_out = src1_sz;
return vtcm_src0_size + src1_sz;
}
#ifdef __cplusplus
}
#endif
#endif // HTP_MATMUL_OPS_H
-4
View File
@@ -14,8 +14,6 @@ Drivers_Dir = 13
1 = %DiskId%
[SourceDisksFiles]
libggml-htp-v68.so = 1
libggml-htp-v69.so = 1
libggml-htp-v73.so = 1
libggml-htp-v75.so = 1
libggml-htp-v79.so = 1
@@ -28,8 +26,6 @@ ExcludeFromSelect = *
CopyFiles=Drivers_Dir
[Drivers_Dir]
libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE
libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE
+3
View File
@@ -192,7 +192,10 @@ set(GGML_OPENCL_KERNELS
mul_mm_f16_f32_kq_kqv
conv2d
conv2d_f16_f32
flash_attn_pre_f16
flash_attn_f32_f16
flash_attn_f32_q8_0
flash_attn_f32_q4_0
flash_attn_f16
flash_attn_f32
)
+91
View File
@@ -0,0 +1,91 @@
#pragma once
// Flash-attention per-(dk,dv) tile tuning for the Adreno OpenCL backend.
// Isolated from ggml-opencl.cpp so the tuning numbers are easy to find and
// edit; the FA dispatch and kernel-compile logic stay in the main file.
// This header is a file section — it is #included exactly once, at the point
// in ggml-opencl.cpp where the ggml logging macros are already in scope.
// Per-(dk, dv) FA config; shared by dispatch and supports_op.
struct ggml_opencl_fa_dim {
int dk; int dv; int bm; int bn; int n_split; int nkv_split_threshold;
};
// Split variant fires when n_kv >= threshold (threshold=0 -> always split).
// Default tuning covers Adreno 7xx/8xx mobile and X1-series laptop GPUs.
static const ggml_opencl_fa_dim g_fa_dims_adreno_default[] = {
{ 40, 40, 64, 32, 1, 0}, { 64, 64, 64, 32, 2, 64},
{ 80, 80, 64, 32, 2, 64}, { 96, 96, 64, 32, 2, 64},
{112, 112, 64, 32, 2, 64}, {128, 128, 64, 32, 2, 64},
{192, 128, 16, 16, 1, 0},
{192, 192, 16, 16, 1, 0},
{256, 256, 16, 16, 16, 0},
};
struct ggml_opencl_fa_dim_table {
const ggml_opencl_fa_dim * data;
size_t count;
const ggml_opencl_fa_dim * begin() const { return data; }
const ggml_opencl_fa_dim * end() const { return data + count; }
};
// Mutable copy of the active table; GGML_OPENCL_FA_TUNE patches entries here
// at backend init without touching the const source table.
static ggml_opencl_fa_dim g_fa_dims_runtime[
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0])];
static ggml_opencl_fa_dim_table g_opencl_fa_dims = {
g_fa_dims_adreno_default,
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]),
};
// GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,…] — patches matching entries
// in the active table at backend init, before the first FA kernel compiles.
// Unmatched (dk,dv) pairs are warned and ignored.
static void ggml_opencl_fa_apply_env_overrides() {
const char * e = std::getenv("GGML_OPENCL_FA_TUNE");
if (!e || !e[0]) {
return;
}
std::string s = e;
size_t pos = 0;
while (pos < s.size()) {
size_t comma = s.find(',', pos);
std::string entry = s.substr(pos, comma == std::string::npos ? std::string::npos : comma - pos);
int dk, dv, bm, bn, nsplit, thr;
if (std::sscanf(entry.c_str(), "%d:%d:%d:%d:%d:%d", &dk, &dv, &bm, &bn, &nsplit, &thr) == 6) {
bool patched = false;
for (size_t i = 0; i < g_opencl_fa_dims.count; ++i) {
ggml_opencl_fa_dim & d = g_fa_dims_runtime[i];
if (d.dk == dk && d.dv == dv) {
d.bm = bm; d.bn = bn; d.n_split = nsplit; d.nkv_split_threshold = thr;
GGML_LOG_INFO("ggml_opencl: FA tune override DK=%d DV=%d -> bm=%d bn=%d n_split=%d thr=%d\n",
dk, dv, bm, bn, nsplit, thr);
patched = true;
break;
}
}
if (!patched) {
GGML_LOG_WARN("ggml_opencl: FA tune override DK=%d DV=%d ignored (no matching dim)\n", dk, dv);
}
} else {
GGML_LOG_WARN("ggml_opencl: FA tune override entry malformed: '%s'\n", entry.c_str());
}
if (comma == std::string::npos) break;
pos = comma + 1;
}
}
// Copy the default table into the mutable runtime buffer and apply any
// GGML_OPENCL_FA_TUNE overrides. A per-generation table can be added here
// once it has been tuned on hardware.
static void ggml_cl_init_fa_dims_table() {
const size_t count = sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]);
for (size_t i = 0; i < count; ++i) {
g_fa_dims_runtime[i] = g_fa_dims_adreno_default[i];
}
g_opencl_fa_dims = { g_fa_dims_runtime, count };
ggml_opencl_fa_apply_env_overrides();
}
File diff suppressed because it is too large Load Diff
+152
View File
@@ -1582,6 +1582,158 @@ kernel void kernel_restore_block_q8_0(
}
}
// View-aware AoS q8_0 -> f32 dequant (f32/f32 FA path).
kernel void kernel_dequant_q8_0_f32_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global float * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
float d = vload_half(0, (global half *)block);
global char * qs = block + 2;
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global float * out = dst + (dst_row_base + blk_i0) * QK8_0;
for (int i = 0; i < QK8_0; ++i) {
out[i] = d * (float)qs[i];
}
}
// View-aware AoS q8_0 -> f16 dequant. Rows tight, batch strides may be gapped.
kernel void kernel_dequant_q8_0_f16_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global half * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
float d = vload_half(0, (global half *)block);
global char * qs = block + 2;
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global half * out = dst + (dst_row_base + blk_i0) * QK8_0;
for (int i = 0; i < QK8_0; ++i) {
out[i] = (half)(d * (float)qs[i]);
}
}
// View-aware AoS q4_0 -> f32 dequant (mirrors the q8_0 view variant).
kernel void kernel_dequant_q4_0_f32_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global float * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
float d = vload_half(0, (global half *)block);
global uchar * qs = (global uchar *)(block + 2);
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global float * out = dst + (dst_row_base + blk_i0) * QK4_0;
for (int i = 0; i < QK4_0/2; ++i) {
uchar byte = qs[i];
int q0 = (int)(byte & 0x0F) - 8;
int q1 = (int)(byte >> 4) - 8;
out[i] = d * (float)q0;
out[i + QK4_0/2] = d * (float)q1;
}
}
// View-aware AoS q4_0 -> f16 dequant (mirrors the q8_0 view variant).
kernel void kernel_dequant_q4_0_f16_view_aos(
global char * src,
ulong src_offset,
ulong src_nb1,
ulong src_nb2,
ulong src_nb3,
int nblk0,
int ne1,
int ne2,
int ne3,
global half * dst
) {
int blk_i0 = get_global_id(0);
int i1 = get_global_id(1);
int batch = get_global_id(2);
if (blk_i0 >= nblk0) return;
if (i1 >= ne1) return;
int i2 = batch % ne2;
int i3 = batch / ne2;
if (i3 >= ne3) return;
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
float d = vload_half(0, (global half *)block);
global uchar * qs = (global uchar *)(block + 2);
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
global half * out = dst + (dst_row_base + blk_i0) * QK4_0;
for (int i = 0; i < QK4_0/2; ++i) {
uchar byte = qs[i];
int q0 = (int)(byte & 0x0F) - 8;
int q1 = (int)(byte >> 4) - 8;
out[i] = (half)(d * (float)q0);
out[i + QK4_0/2] = (half)(d * (float)q1);
}
}
kernel void kernel_restore_block_q8_0_trans(
global uchar * src_q,
global half * src_d,
+75 -40
View File
@@ -4,14 +4,26 @@
#define ACC_TYPE4 float4
#define DATA_TYPE half
#define DATA_TYPE4 half4
#define CONVERT_ACC4(x) convert_float4(x)
#define CONVERT_DATA4(x) convert_half4(x)
#define CONVERT_ACC4(x) ((float4)((float)(x).s0, (float)(x).s1, (float)(x).s2, (float)(x).s3))
#define CONVERT_DATA4(x) ((half4)((half)(x).s0, (half)(x).s1, (half)(x).s2, (half)(x).s3))
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -81,18 +93,18 @@ __kernel void flash_attn_f16(
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -125,49 +137,72 @@ __kernel void flash_attn_f16(
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL) {
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
const ACC_TYPE p0 = native_exp(s0 - m_new);
const ACC_TYPE p1 = native_exp(s1 - m_new);
const ACC_TYPE p2 = native_exp(s2 - m_new);
const ACC_TYPE p3 = native_exp(s3 - m_new);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1;
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
@@ -179,7 +214,7 @@ __kernel void flash_attn_f16(
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -191,12 +226,12 @@ __kernel void flash_attn_f16(
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
@@ -258,7 +293,7 @@ __kernel void flash_attn_f16_q1(
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
@@ -270,12 +305,12 @@ __kernel void flash_attn_f16_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -293,7 +328,7 @@ __kernel void flash_attn_f16_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -301,7 +336,7 @@ __kernel void flash_attn_f16_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -311,7 +346,7 @@ __kernel void flash_attn_f16_q1(
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -325,7 +360,7 @@ __kernel void flash_attn_f16_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -335,7 +370,7 @@ __kernel void flash_attn_f16_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -354,7 +389,7 @@ __kernel void flash_attn_f16_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -364,7 +399,7 @@ __kernel void flash_attn_f16_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
+73 -38
View File
@@ -13,6 +13,18 @@
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -82,18 +94,18 @@ __kernel void flash_attn_f32(
if (my_query_row < n_q) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -126,49 +138,72 @@ __kernel void flash_attn_f32(
continue;
}
for (int j = 0; j < BLOCK_N; j += 2) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
const ACC_TYPE p0 = native_exp(s0 - m_new);
const ACC_TYPE p1 = native_exp(s1 - m_new);
const ACC_TYPE p2 = native_exp(s2 - m_new);
const ACC_TYPE p3 = native_exp(s3 - m_new);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1;
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
@@ -180,7 +215,7 @@ __kernel void flash_attn_f32(
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -192,12 +227,12 @@ __kernel void flash_attn_f32(
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (DATA_TYPE4)(0.0f);
}
@@ -259,7 +294,7 @@ __kernel void flash_attn_f32_q1(
ACC_TYPE4 q_priv[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
}
@@ -271,12 +306,12 @@ __kernel void flash_attn_f32_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -294,7 +329,7 @@ __kernel void flash_attn_f32_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -302,7 +337,7 @@ __kernel void flash_attn_f32_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -312,7 +347,7 @@ __kernel void flash_attn_f32_q1(
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
}
@@ -326,7 +361,7 @@ __kernel void flash_attn_f32_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -336,7 +371,7 @@ __kernel void flash_attn_f32_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -355,7 +390,7 @@ __kernel void flash_attn_f32_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -365,7 +400,7 @@ __kernel void flash_attn_f32_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
}
}
@@ -1,5 +1,13 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_khr_subgroup_shuffle
#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable
#define HAS_SUBGROUP_SHUFFLE 1
#elif defined(cl_qcom_subgroup_shuffle)
#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
#define HAS_SUBGROUP_SHUFFLE 1
#endif
#define ACC_TYPE float
#define ACC_TYPE4 float4
#define Q_DATA_TYPE4 float4
@@ -12,9 +20,34 @@
#define DK_VEC (DK/4)
#define DV_VEC (DV/4)
#define WG_SIZE (BLOCK_M)
#define Q1_WG_SIZE 64
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
// infinite operand can cause undefined behavior and miscompilation for exp.
// Therefore, a large negative value is used instead.
#define FA_M_INIT (-3.0e38f)
// Drop full unroll at DK>=192 Adreno compiler host-memory budget.
#if DK >= 192
#define FA_UNROLL
#else
#define FA_UNROLL _Pragma("unroll")
#endif
// N_SPLIT>1 splits DK/DV across threads to cut per-thread register use.
#ifndef N_SPLIT
#define N_SPLIT 1
#endif
#define SPLIT_DK_VEC (DK_VEC / N_SPLIT)
#define SPLIT_DV_VEC (DV_VEC / N_SPLIT)
#if N_SPLIT > 1
#define WG_SIZE (BLOCK_M * N_SPLIT)
#else
#define WG_SIZE (BLOCK_M)
#endif
inline float get_alibi_slope(
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
@@ -54,19 +87,38 @@ __kernel void flash_attn_f32_f16(
const int mask_ne2,
const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
const ulong sinks_offset,
const global void * k_pad_void,
const global void * v_pad_void,
const global void * mask_pad_void,
const global char * blk,
const int n_kv_blocks,
const ulong mask_pad_nb1,
const ulong mask_pad_nb2,
const ulong mask_pad_nb3
) {
const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0);
const int head_batch_idx = get_global_id(1);
const int my_query_row = block_q_idx * BLOCK_M + tid;
#if N_SPLIT > 1
const int q_lane = tid / N_SPLIT;
const int split_idx = tid % N_SPLIT;
#else
const int q_lane = tid;
const int split_idx = 0;
#endif
const int my_query_row = block_q_idx * BLOCK_M + q_lane;
const int query_valid = my_query_row < n_q;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const int mask_head_idx = mask_void != NULL ? head_idx % mask_ne2 : 0;
const int mask_batch_idx = mask_void != NULL ? batch_idx % mask_ne3 : 0;
const global char* q_base = (const global char*)q_void + q_offset;
const global char* k_base = (const global char*)k_void + k_offset;
@@ -75,27 +127,41 @@ __kernel void flash_attn_f32_f16(
const global char* mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
const global char* mask_pad_base = NULL;
if (mask_pad_void != NULL) {
mask_pad_base = (const global char*)mask_pad_void + mask_batch_idx * mask_pad_nb3 + mask_head_idx * mask_pad_nb2;
}
const global char* blk_base = NULL;
if (blk != NULL) {
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
blk_base = blk + (((mask_batch_idx * mask_ne2) + mask_head_idx) * n_q_blocks + block_q_idx) * n_kv_blocks;
}
ACC_TYPE4 q_priv[DK_VEC];
if (my_query_row < n_q) {
ACC_TYPE4 q_priv[SPLIT_DK_VEC];
const int dk_off = split_idx * SPLIT_DK_VEC;
if (query_valid) {
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
FA_UNROLL
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[dk_off + i]);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
q_priv[i] = (ACC_TYPE4)(0.0f);
}
}
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
ACC_TYPE4 o_acc[SPLIT_DV_VEC];
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = (ACC_TYPE4)(0.0f);
}
ACC_TYPE m_i = -INFINITY;
ACC_TYPE m_i = FA_M_INIT;
ACC_TYPE l_i = 0.0f;
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -103,86 +169,369 @@ __kernel void flash_attn_f32_f16(
__local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
__local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
#if N_SPLIT > 1 && !defined(HAS_SUBGROUP_SHUFFLE)
__local ACC_TYPE local_partial[BLOCK_N][WG_SIZE];
__local ACC_TYPE local_p[BLOCK_M][BLOCK_N];
__local ACC_TYPE local_softmax_scale[BLOCK_M];
__local ACC_TYPE local_l_inv[BLOCK_M];
#endif
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
char blk_cur = 1;
if (blk_base != NULL) {
blk_cur = blk_base[k_start / BLOCK_N];
if (blk_cur == 0) continue;
}
const int use_kv_pad = k_pad_void != NULL && k_start + BLOCK_N > n_kv;
const int k_tile_start = use_kv_pad ? 0 : k_start;
const ulong k_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * k_nb1 : k_nb2;
const ulong k_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * k_tile_nb2 : k_nb3;
const ulong v_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * v_nb1 : v_nb2;
const ulong v_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * v_tile_nb2 : v_nb3;
const global char* k_tile_base = use_kv_pad ? (const global char*) k_pad_void : k_base;
const global char* v_tile_base = use_kv_pad ? (const global char*) v_pad_void : v_base;
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
const int row = i / DK_VEC;
const int col = i % DK_VEC;
const int k_row_idx = k_start + row;
if (k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
const int k_row_idx = k_tile_start + row;
if (use_kv_pad || k_row_idx < n_kv) {
const ulong k_row_offset = batch_idx * k_tile_nb3 + head_kv_idx * k_tile_nb2 + k_row_idx * k_nb1;
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_tile_base + k_row_offset))[col];
} else {
l_k[row][col] = (KV_DATA_TYPE4)(0.0h);
}
}
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
const int row = i / DV_VEC;
const int col = i % DV_VEC;
const int v_row_idx = k_start + row;
if (v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
const int v_row_idx = k_tile_start + row;
if (use_kv_pad || v_row_idx < n_kv) {
const ulong v_row_offset = batch_idx * v_tile_nb3 + head_kv_idx * v_tile_nb2 + v_row_idx * v_nb1;
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_tile_base + v_row_offset))[col];
} else {
l_v[row][col] = (KV_DATA_TYPE4)(0.0h);
}
}
barrier(CLK_LOCAL_MEM_FENCE);
if (my_query_row >= n_q) {
continue;
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
{
const int dv_off = split_idx * SPLIT_DV_VEC;
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE partial0 = 0.0f;
ACC_TYPE partial1 = 0.0f;
FA_UNROLL
for (int k = 0; k < SPLIT_DK_VEC; k++) {
const ACC_TYPE4 qk = q_priv[k];
ACC_TYPE4 dot0 = qk * CONVERT_KV_ACC4(l_k[j ][dk_off + k]);
ACC_TYPE4 dot1 = qk * CONVERT_KV_ACC4(l_k[j+1][dk_off + k]);
partial0 += dot0.s0 + dot0.s1 + dot0.s2 + dot0.s3;
partial1 += dot1.s0 + dot1.s1 + dot1.s2 + dot1.s3;
}
FA_UNROLL
for (int step = 1; step < N_SPLIT; step <<= 1) {
partial0 += sub_group_shuffle_xor(partial0, step);
partial1 += sub_group_shuffle_xor(partial1, step);
}
ACC_TYPE score0 = partial0 * scale;
ACC_TYPE score1 = partial1 * scale;
if (!query_valid) { score0 = FA_M_INIT; score1 = FA_M_INIT; }
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = FA_M_INIT;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = FA_M_INIT;
}
if (k_row0 >= n_kv) score0 = FA_M_INIT;
if (k_row1 >= n_kv) score1 = FA_M_INIT;
if (query_valid && mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
score0 += slope * (ACC_TYPE)mask_ptr[j];
score1 += slope * (ACC_TYPE)mask_ptr[j + 1];
} else {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
// far negative so the tile contributes 0, not exp(0)=1.
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE sp = native_exp(m_i - m_exp);
const ACC_TYPE p0 = native_exp(score0 - m_exp);
const ACC_TYPE p1 = native_exp(score1 - m_exp);
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = o_acc[i] * sp
+ p0 * CONVERT_KV_ACC4(l_v[j ][dv_off + i])
+ p1 * CONVERT_KV_ACC4(l_v[j+1][dv_off + i]);
}
l_i = l_i * sp + p0 + p1;
m_i = m_new;
}
}
for (int j = 0; j < BLOCK_N; j += 2) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; k++) {
dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
#elif N_SPLIT > 1
// N_SPLIT>1 fallback (no shuffle): 3-phase local-memory reduction.
// Phase 1 partial dots for all BLOCK_N tokens.
for (int j = 0; j < BLOCK_N; ++j) {
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < SPLIT_DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][dk_off + k]), dot_acc);
}
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
if (is_causal) {
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
}
if (k_row0 >= n_kv) score0 = -INFINITY;
if (k_row1 >= n_kv) score1 = -INFINITY;
if (mask_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
}
if (logit_softcap > 0.0f) {
score0 = logit_softcap * tanh(score0 / logit_softcap);
score1 = logit_softcap * tanh(score1 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(score0, score1));
const ACC_TYPE p0 = exp(score0 - m_new);
const ACC_TYPE p1 = exp(score1 - m_new);
const ACC_TYPE scale_prev = exp(m_i - m_new);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
}
l_i = l_i * scale_prev + p0 + p1;
m_i = m_new;
local_partial[j][tid] =
dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3;
}
barrier(CLK_LOCAL_MEM_FENCE); // 1 barrier: partial dots visible
// Phase 2 split_idx==0 reduces partial sums and computes block softmax.
if (split_idx == 0) {
if (query_valid) {
ACC_TYPE m_new = m_i;
for (int j = 0; j < BLOCK_N; ++j) {
const int k_row = k_start + j;
ACC_TYPE score = 0.0f;
FA_UNROLL
for (int s = 0; s < N_SPLIT; s++) {
score += local_partial[j][q_lane * N_SPLIT + s];
}
score *= scale;
if (is_causal && k_row > (n_kv - n_q + my_query_row)) score = FA_M_INIT;
if (k_row >= n_kv) score = FA_M_INIT;
if (mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
score += slope * (ACC_TYPE)mask_ptr[j];
} else {
const global MASK_DATA_TYPE* mask_ptr =
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row < n_kv) score += slope * (ACC_TYPE)mask_ptr[k_row];
}
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_new = max(m_new, score);
local_p[q_lane][j] = score;
}
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE sp = native_exp(m_i - m_exp);
ACC_TYPE l_new = l_i * sp;
for (int j = 0; j < BLOCK_N; ++j) {
const ACC_TYPE p = native_exp(local_p[q_lane][j] - m_exp);
local_p[q_lane][j] = p;
l_new += p;
}
local_softmax_scale[q_lane] = sp;
l_i = l_new;
m_i = m_new;
} else {
local_softmax_scale[q_lane] = 1.0f;
for (int j = 0; j < BLOCK_N; ++j) local_p[q_lane][j] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
// Phase 3 V accumulate using broadcast probabilities.
{
const ACC_TYPE sp_block = local_softmax_scale[q_lane];
const int dv_off = split_idx * SPLIT_DV_VEC;
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] *= sp_block;
}
for (int j = 0; j < BLOCK_N; ++j) {
const ACC_TYPE p = local_p[q_lane][j];
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(l_v[j][dv_off + i]), o_acc[i]);
}
}
}
#else
// N_SPLIT==1: j+=4 unroll. Requires BLOCK_N % 4 == 0.
if (query_valid) {
for (int j = 0; j < BLOCK_N; j += 4) {
const int k_row0 = k_start + j;
const int k_row1 = k_start + j + 1;
const int k_row2 = k_start + j + 2;
const int k_row3 = k_start + j + 3;
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
const ACC_TYPE4 qk = q_priv[k];
dot_acc0 = mad(qk, CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
dot_acc1 = mad(qk, CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
dot_acc2 = mad(qk, CONVERT_KV_ACC4(l_k[j+2][k]), dot_acc2);
dot_acc3 = mad(qk, CONVERT_KV_ACC4(l_k[j+3][k]), dot_acc3);
}
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
if (is_causal) {
const int causal_limit = n_kv - n_q + my_query_row;
if (k_row0 > causal_limit) s0 = FA_M_INIT;
if (k_row1 > causal_limit) s1 = FA_M_INIT;
if (k_row2 > causal_limit) s2 = FA_M_INIT;
if (k_row3 > causal_limit) s3 = FA_M_INIT;
}
if (k_row0 >= n_kv) s0 = FA_M_INIT;
if (k_row1 >= n_kv) s1 = FA_M_INIT;
if (k_row2 >= n_kv) s2 = FA_M_INIT;
if (k_row3 >= n_kv) s3 = FA_M_INIT;
if (mask_base != NULL && blk_cur != 2) {
if (use_kv_pad && mask_pad_base != NULL) {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
s0 += slope * (ACC_TYPE)mask_ptr[j];
s1 += slope * (ACC_TYPE)mask_ptr[j + 1];
s2 += slope * (ACC_TYPE)mask_ptr[j + 2];
s3 += slope * (ACC_TYPE)mask_ptr[j + 3];
} else {
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
}
}
if (logit_softcap > 0.0f) {
s0 = logit_softcap * tanh(s0 / logit_softcap);
s1 = logit_softcap * tanh(s1 / logit_softcap);
s2 = logit_softcap * tanh(s2 / logit_softcap);
s3 = logit_softcap * tanh(s3 / logit_softcap);
}
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
// far negative so the tile contributes 0, not exp(0)=1.
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
const ACC_TYPE scale_prev = native_exp(m_i - m_exp);
const ACC_TYPE p0 = native_exp(s0 - m_exp);
const ACC_TYPE p1 = native_exp(s1 - m_exp);
const ACC_TYPE p2 = native_exp(s2 - m_exp);
const ACC_TYPE p3 = native_exp(s3 - m_exp);
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = mad(p3, CONVERT_KV_ACC4(l_v[j+3][i]),
mad(p2, CONVERT_KV_ACC4(l_v[j+2][i]),
mad(p1, CONVERT_KV_ACC4(l_v[j+1][i]),
mad(p0, CONVERT_KV_ACC4(l_v[j][i]),
o_acc[i] * scale_prev))));
}
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
m_i = m_new;
}
}
#endif
// End of tile: every thread must finish reading l_k/l_v before the
// next iteration's load overwrites them (WAR hazard on local memory).
barrier(CLK_LOCAL_MEM_FENCE);
}
if (my_query_row < n_q) {
// Write output.
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
if (query_valid) {
ACC_TYPE sinks_sp = 1.0f;
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
sinks_sp = exp(m_i - m_final);
l_i = l_i * sinks_sp + exp(m_sink - m_final);
m_i = m_final;
}
const ACC_TYPE l_inv = (l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
const int dv_off = split_idx * SPLIT_DV_VEC;
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_inv > 0.0f) {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#elif N_SPLIT > 1
if (split_idx == 0) {
ACC_TYPE sinks_sp = 1.0f;
if (query_valid && sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
sinks_sp = exp(m_i - m_final);
l_i = l_i * sinks_sp + exp(m_sink - m_final);
m_i = m_final;
}
local_softmax_scale[q_lane] = sinks_sp;
local_l_inv[q_lane] = (query_valid && l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
}
barrier(CLK_LOCAL_MEM_FENCE);
if (query_valid) {
const ACC_TYPE sinks_sp = local_softmax_scale[q_lane];
const ACC_TYPE l_inv = local_l_inv[q_lane];
const int dv_off = split_idx * SPLIT_DV_VEC;
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_inv > 0.0f) {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
}
} else {
FA_UNROLL
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#else
if (query_valid) {
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
@@ -194,17 +543,18 @@ __kernel void flash_attn_f32_f16(
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_i;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
}
} else {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) {
o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
}
#endif
}
__kernel void flash_attn_f32_f16_q1(
@@ -258,13 +608,16 @@ __kernel void flash_attn_f32_f16_q1(
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
}
ACC_TYPE4 q_priv[DK_VEC];
// Q is uniform across WG threads (n_q=1). Share via local memory to
// avoid per-thread q_priv[DK_VEC] dynamic-indexed private array that
// spills to DDR on Adreno.
__local ACC_TYPE4 q_shared[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
#pragma unroll
for (int i = 0; i < DK_VEC; ++i) {
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
}
barrier(CLK_LOCAL_MEM_FENCE);
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
@@ -273,14 +626,14 @@ __kernel void flash_attn_f32_f16_q1(
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
@@ -296,7 +649,7 @@ __kernel void flash_attn_f32_f16_q1(
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
@@ -304,7 +657,7 @@ __kernel void flash_attn_f32_f16_q1(
const ACC_TYPE m_final = local_m[0];
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
@@ -314,9 +667,9 @@ __kernel void flash_attn_f32_f16_q1(
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
FA_UNROLL
for (int k = 0; k < DK_VEC; k++) {
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
@@ -328,7 +681,7 @@ __kernel void flash_attn_f32_f16_q1(
}
const ACC_TYPE p = exp(score - m_final);
l_i += p;
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; i++) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
}
@@ -338,7 +691,7 @@ __kernel void flash_attn_f32_f16_q1(
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -357,7 +710,7 @@ __kernel void flash_attn_f32_f16_q1(
for (int i = 0; i < DV_VEC; i++) {
local_o_comp[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
FA_UNROLL
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
@@ -367,7 +720,257 @@ __kernel void flash_attn_f32_f16_q1(
}
}
} else if (tid == 0) {
#pragma unroll
FA_UNROLL
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
}
}
// Flash-decoding split pass. gid(2) = q_idx * n_splits + split_idx.
// Partial record per split: [m, l, O[DV]]. Merge kernel applies sink + norm.
#define FA_PARTIAL_FLOATS (2 + DV)
__kernel void flash_attn_f32_f16_q1_split(
const global void * q_void, ulong q_offset,
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
const float scale,
const int n_q,
const int n_kv,
const int n_head,
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
const float max_bias,
const float m0,
const float m1,
const int n_head_log2,
const float logit_softcap,
const int n_head_kv,
const global void * mask_void,
const ulong mask_offset,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3,
global float * partial_void,
const int n_splits,
const int kv_per_split
) {
const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1);
const int split_q_idx = get_global_id(2);
const int split_idx = split_q_idx % n_splits;
const int q_idx = split_q_idx / n_splits;
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const int gqa_ratio = n_head / n_head_kv;
const int head_kv_idx = head_idx / gqa_ratio;
const int kv_start = split_idx * kv_per_split;
const int kv_end = min(kv_start + kv_per_split, n_kv);
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
const ulong record_idx = ((((ulong) batch_idx * n_head + head_idx) * n_q + q_idx)
* n_splits + split_idx);
global float * rec = partial_void + record_idx * record_stride;
global float4 * rec_o = (global float4 *) (rec + 2);
if (kv_start >= kv_end) {
// Empty split: leave sentinel partial for merge.
if (tid == 0) {
rec[0] = FA_M_INIT;
rec[1] = 0.0f;
}
return;
}
const global char * q_base = (const global char *) q_void + q_offset;
const global char * k_base = (const global char *) k_void + k_offset;
const global char * v_base = (const global char *) v_void + v_offset;
const global char * mask_base = NULL;
if (mask_void != NULL) {
const int mask_head_idx = head_idx % mask_ne2;
const int mask_batch_idx = batch_idx % mask_ne3;
mask_base = (const global char *) mask_void + mask_offset +
mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2 +
(ulong) q_idx * mask_nb1;
}
// Share Q via local memory (n_q=1 per split -> uniform across WG).
__local ACC_TYPE4 q_shared[DK_VEC];
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + (ulong) q_idx * q_nb1;
const global Q_DATA_TYPE4 * q_ptr = (const global Q_DATA_TYPE4 *) (q_base + q_row_offset);
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
}
barrier(CLK_LOCAL_MEM_FENCE);
const float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
// Pass 1a split-local max.
ACC_TYPE m_i = FA_M_INIT;
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; ++k) {
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
score += slope * (ACC_TYPE) mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
m_i = max(m_i, score);
}
__local ACC_TYPE local_m[Q1_WG_SIZE];
local_m[tid] = m_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
barrier(CLK_LOCAL_MEM_FENCE);
}
const ACC_TYPE m_c = local_m[0];
// Pass 1b softmax-weighted V accumulate.
ACC_TYPE4 o_acc[DV_VEC];
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
ACC_TYPE l_i = 0.0f;
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
const global KV_DATA_TYPE4 * v_ptr = (const global KV_DATA_TYPE4 *) (v_base + v_row_offset);
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
#pragma unroll
for (int k = 0; k < DK_VEC; ++k) {
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
}
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
if (mask_base != NULL) {
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
score += slope * (ACC_TYPE) mask_ptr[k_idx];
}
if (logit_softcap > 0.0f) {
score = logit_softcap * tanh(score / logit_softcap);
}
const ACC_TYPE p = exp(score - m_c);
l_i += p;
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
}
}
__local ACC_TYPE local_l[Q1_WG_SIZE];
__local ACC_TYPE4 local_o[Q1_WG_SIZE];
local_l[tid] = l_i;
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_l[tid] += local_l[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
const ACC_TYPE l_c = local_l[0];
if (tid == 0) {
rec[0] = (float) m_c;
rec[1] = (float) l_c;
}
for (int i = 0; i < DV_VEC; ++i) {
local_o[tid] = o_acc[i];
barrier(CLK_LOCAL_MEM_FENCE);
#pragma unroll
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) local_o[tid] += local_o[tid + s];
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
rec_o[i] = local_o[0];
}
}
}
// FD Pass 2: merge per-split partials into final O. Empty splits drop via exp(-INF)=0.
__kernel void flash_attn_f32_merge(
const global float * partial_void,
global void * o_void,
const ulong o_offset,
const int n_head,
const int n_splits,
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
const global void * sinks_void,
const ulong sinks_offset,
const int n_q
) {
const int lane = get_local_id(0); // 0..DV_VEC-1
const int head_batch_idx = get_global_id(1);
const int q_idx = get_global_id(2);
const int batch_idx = head_batch_idx / n_head;
const int head_idx = head_batch_idx % n_head;
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
const ulong record_idx_0 = (((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) * n_splits;
const global float * rec0 = partial_void + record_idx_0 * record_stride;
__local ACC_TYPE m_final_shared;
__local ACC_TYPE l_final_shared;
if (lane == 0) {
ACC_TYPE m = FA_M_INIT;
for (int c = 0; c < n_splits; ++c) {
const ACC_TYPE m_c = rec0[c * record_stride + 0];
m = max(m, m_c);
}
ACC_TYPE m_sink = 0.0f;
bool has_sink = false;
if (sinks_void != NULL) {
const global ACC_TYPE * sinks_ptr =
(const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset);
m_sink = sinks_ptr[head_idx];
has_sink = true;
m = max(m, m_sink);
}
ACC_TYPE l = 0.0f;
for (int c = 0; c < n_splits; ++c) {
const ACC_TYPE m_c = rec0[c * record_stride + 0];
const ACC_TYPE l_c = rec0[c * record_stride + 1];
if (m_c > FA_M_INIT) {
l += l_c * exp(m_c - m);
}
}
if (has_sink) {
l += exp(m_sink - m);
}
m_final_shared = m;
l_final_shared = l;
}
barrier(CLK_LOCAL_MEM_FENCE);
const ACC_TYPE m_final = m_final_shared;
const ACC_TYPE l_final = l_final_shared;
const ACC_TYPE l_inv = (l_final > 0.0f) ? (1.0f / l_final) : 0.0f;
ACC_TYPE4 o = (ACC_TYPE4)(0.0f);
for (int c = 0; c < n_splits; ++c) {
const global float * rec_c = rec0 + c * record_stride;
const ACC_TYPE m_c = rec_c[0];
if (m_c <= FA_M_INIT) continue;
const global float4 * rec_oc = (const global float4 *) (rec_c + 2);
const ACC_TYPE scale_c = exp(m_c - m_final);
o = mad((ACC_TYPE4)(scale_c), rec_oc[lane], o);
}
o = o * l_inv;
const ulong o_row_offset = (ulong) batch_idx * o_nb3 + (ulong) q_idx * o_nb2 + (ulong) head_idx * o_nb1;
global O_DATA_TYPE4 * o_row = (global O_DATA_TYPE4 *) ((global char *) o_void + o_offset + o_row_offset);
o_row[lane] = CONVERT_O_DATA4(o);
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,156 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void flash_attn_kv_pad_f16(
const global void * k_void, ulong k_offset,
const global void * v_void, ulong v_offset,
global void * k_pad_void,
global void * v_pad_void,
const int n_kv,
const int n_head_kv,
const int n_batch,
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3
) {
const int row_idx = get_global_id(0);
const int head_kv_idx = get_global_id(1);
const int batch_idx = get_global_id(2);
if (row_idx >= BLOCK_N || head_kv_idx >= n_head_kv || batch_idx >= n_batch) {
return;
}
const int tail_start = n_kv - (n_kv % BLOCK_N);
const int src_row_idx = tail_start + row_idx;
const global char * k_src = (const global char *) k_void + k_offset;
const global char * v_src = (const global char *) v_void + v_offset;
global char * k_pad = (global char *) k_pad_void;
global char * v_pad = (global char *) v_pad_void;
const ulong k_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * k_nb1) + (ulong) row_idx * k_nb1;
const ulong v_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * v_nb1) + (ulong) row_idx * v_nb1;
if (src_row_idx < n_kv) {
const ulong k_src_offset = (ulong) batch_idx * k_nb3 + (ulong) head_kv_idx * k_nb2 + (ulong) src_row_idx * k_nb1;
const ulong v_src_offset = (ulong) batch_idx * v_nb3 + (ulong) head_kv_idx * v_nb2 + (ulong) src_row_idx * v_nb1;
for (ulong i = 0; i < k_nb1; ++i) {
k_pad[k_dst_offset + i] = k_src[k_src_offset + i];
}
for (ulong i = 0; i < v_nb1; ++i) {
v_pad[v_dst_offset + i] = v_src[v_src_offset + i];
}
} else {
for (ulong i = 0; i < k_nb1; ++i) {
k_pad[k_dst_offset + i] = 0;
}
for (ulong i = 0; i < v_nb1; ++i) {
v_pad[v_dst_offset + i] = 0;
}
}
}
__kernel void flash_attn_mask_pad_f16(
const global void * mask_void, ulong mask_offset,
global void * mask_pad_void,
const int n_q,
const int n_kv,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int col_idx = get_global_id(0);
const int q_row = get_global_id(1);
const int mask_slice = get_global_id(2);
if (col_idx >= BLOCK_N || q_row >= n_q || mask_slice >= mask_ne2 * mask_ne3) {
return;
}
const int tail_start = n_kv - (n_kv % BLOCK_N);
const int src_col_idx = tail_start + col_idx;
const int mask_head_idx = mask_slice % mask_ne2;
const int mask_batch_idx = mask_slice / mask_ne2;
const global char * mask_src_base = (const global char *) mask_void + mask_offset +
(ulong) mask_batch_idx * mask_nb3 +
(ulong) mask_head_idx * mask_nb2 +
(ulong) q_row * mask_nb1;
const global half * mask_src = (const global half *) mask_src_base;
global half * mask_pad = (global half *) mask_pad_void;
const ulong dst_idx =
(((ulong) mask_batch_idx * (ulong) mask_ne2 + (ulong) mask_head_idx) * (ulong) n_q + (ulong) q_row) * (ulong) BLOCK_N +
(ulong) col_idx;
mask_pad[dst_idx] = src_col_idx < n_kv ? mask_src[src_col_idx] : (half) (-INFINITY);
}
// Per-KV-tile mask class. 0=all -inf (skip tile), 1=mixed (apply mask),
// 2=all zero, no -inf (skip mask lookup). Causal diagonal tiles are class 1.
__kernel void flash_attn_blk_f16(
const global void * mask_void, ulong mask_offset,
global char * blk,
const int n_q,
const int n_kv,
const ulong mask_nb1,
const ulong mask_nb2,
const ulong mask_nb3,
const int mask_ne2,
const int mask_ne3
) {
const int kv_block_idx = get_global_id(0);
const int q_block_idx = get_global_id(1);
const int mask_slice = get_global_id(2);
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
const int n_kv_blocks = (n_kv + BLOCK_N - 1) / BLOCK_N;
if (kv_block_idx >= n_kv_blocks || q_block_idx >= n_q_blocks || mask_slice >= mask_ne2 * mask_ne3) {
return;
}
const int mask_head_idx = mask_slice % mask_ne2;
const int mask_batch_idx = mask_slice / mask_ne2;
const int q_start = q_block_idx * BLOCK_M;
const int k_start = kv_block_idx * BLOCK_N;
const int q_count = min(BLOCK_M, n_q - q_start);
const int k_count = min(BLOCK_N, n_kv - k_start);
const half neg_max_half = (half) (-65504.0f);
char has_unmasked = 0;
char has_masked = 0;
char has_nonzero = 0;
const global char * mask_base = (const global char *) mask_void + mask_offset +
(ulong) mask_batch_idx * mask_nb3 +
(ulong) mask_head_idx * mask_nb2;
for (int qi = 0; qi < q_count; ++qi) {
const global half * mask_row = (const global half *) (mask_base + (ulong) (q_start + qi) * mask_nb1) + k_start;
for (int ki = 0; ki < k_count; ++ki) {
const half v = mask_row[ki];
if (v <= neg_max_half) {
has_masked = 1;
} else {
has_unmasked = 1;
if (v != (half) 0.0f) {
has_nonzero = 1;
}
}
}
if (has_masked && has_unmasked) break; // mixed tile — short-circuit.
}
char res;
if (has_unmasked == 0) {
res = 0;
} else if (has_masked || has_nonzero) {
res = 1;
} else {
res = 2;
}
blk[((ulong) mask_slice * (ulong) n_q_blocks + (ulong) q_block_idx) * (ulong) n_kv_blocks + (ulong) kv_block_idx] = res;
}
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
}
// reduction in local memory, assumes #wave=4
+5 -2
View File
@@ -24,6 +24,7 @@ kernel void kernel_norm(
int ne01,
int ne02,
int ne03,
ulong nb00,
ulong nb01,
ulong nb02,
ulong nb03,
@@ -43,7 +44,8 @@ kernel void kernel_norm(
// parallel sum
sum[get_local_id(0)] = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
sum[get_local_id(0)] += x[i00];
// this kernel handles float, nb00/4 translates byte offset to element offset
sum[get_local_id(0)] += x[i00*nb00/4];
}
// reduce
barrier(CLK_LOCAL_MEM_FENCE);
@@ -60,7 +62,8 @@ kernel void kernel_norm(
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
sum[get_local_id(0)] = 0.0f;
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
y[i00] = x[i00] - mean;
// this kernel handles float, nb00/4 translates byte offset to element offset
y[i00] = x[i00*nb00/4] - mean;
sum[get_local_id(0)] += y[i00] * y[i00];
}
+500
View File
@@ -158,6 +158,239 @@ kernel void kernel_set_rows_f32_i32(
}
}
// f32 -> q8_0 quantize set_rows. Block = half d + char qs[32].
#define QK8_0 32
inline void quantize_q8_0_block(global float * x, global char * qs, global half * d_out) {
float amax = 0.0f;
for (int j = 0; j < QK8_0; j++) {
amax = fmax(amax, fabs(x[j]));
}
float d = amax / 127.0f;
float id = (d != 0.0f) ? 127.0f / amax : 0.0f;
vstore_half(d, 0, d_out);
for (int j = 0; j < QK8_0; j++) {
qs[j] = (char)((int)round(x[j] * id));
}
}
kernel void kernel_set_rows_q8_0_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * y = dst_row + blk * (2 + QK8_0);
quantize_q8_0_block(x, y + 2, (global half *)y);
}
}
kernel void kernel_set_rows_q8_0_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * y = dst_row + blk * (2 + QK8_0);
quantize_q8_0_block(x, y + 2, (global half *)y);
}
}
// SoA q8_0 variants. dst_q: int8[QK8_0] per block; dst_d: fp16 scale per block.
// Layout matches kernel_convert_block_q8_0; block index follows dst element order.
kernel void kernel_set_rows_q8_0_soa_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0;
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * q = q_row + blk * QK8_0;
quantize_q8_0_block(x, q, d_row + blk);
}
}
kernel void kernel_set_rows_q8_0_soa_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global char * q_row = (global char *)(dst_q) + row_blk_base * QK8_0;
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK8_0;
global char * q = q_row + blk * QK8_0;
quantize_q8_0_block(x, q, d_row + blk);
}
}
kernel void kernel_set_rows_f16_i32(
global char * src0,
ulong offset0,
@@ -206,3 +439,270 @@ kernel void kernel_set_rows_f16_i32(
dst_row[ind] = src_row[ind];
}
}
// f32 -> q4_0 quantize set_rows. Block = half d + uchar qs[16] (shuffled
// nibbles: qs[j] low/high = elem j / j+16).
// Dequant: val[i] = d * (nibble_i - 8)
// nblk0 = number of q4_0 blocks per row = ne00 / 32.
#define QK4_0 32
#define Q4_0_BLOCK_SIZE 18
inline void quantize_q4_0_block(global float * x, global uchar * qs, global half * d_out) {
// Find the signed value with the largest absolute magnitude (matches ggml ref).
float max = 0.0f;
float amax = 0.0f;
for (int j = 0; j < QK4_0; j++) {
float v = x[j];
float a = fabs(v);
if (a > amax) {
amax = a;
max = v;
}
}
float d = max / -8.0f;
float id = (d != 0.0f) ? 1.0f / d : 0.0f;
vstore_half(d, 0, d_out);
for (int j = 0; j < QK4_0/2; j++) {
float x0 = x[j] * id;
float x1 = x[j + QK4_0/2] * id;
int i0 = (int)(x0 + 8.5f);
int i1 = (int)(x1 + 8.5f);
if (i0 < 0) i0 = 0;
if (i0 > 15) i0 = 15;
if (i1 < 0) i1 = 0;
if (i1 > 15) i1 = 15;
qs[j] = (uchar)i0 | ((uchar)i1 << 4);
}
}
kernel void kernel_set_rows_q4_0_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global char * y = dst_row + blk * Q4_0_BLOCK_SIZE;
global half * yd = (global half *)(y);
global uchar * yqs = (global uchar *)(y + 2);
quantize_q4_0_block(x, yqs, yd);
}
}
kernel void kernel_set_rows_q4_0_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
global char * dst_row = (global char *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global char * y = dst_row + blk * Q4_0_BLOCK_SIZE;
global half * yd = (global half *)(y);
global uchar * yqs = (global uchar *)(y + 2);
quantize_q4_0_block(x, yqs, yd);
}
}
// SoA variants for q4_0 dst. Used when the backend has split block_q4_0 records
// into separate quant (dst_q) and scale (dst_d) sub-buffers same pattern as
// the q8_0 SoA variants above.
//
// Layout (matches kernel_convert_block_q4_0, the "shuffled" variant):
// dst_q: contiguous 16 packed nibbles per block, block i at offset i * 16 bytes.
// dst_d: contiguous fp16 scales, block i at offset i * 2 bytes.
// Nibble layout inside each byte is unchanged from AoS: qs[j] low nibble = element j,
// qs[j] high nibble = element j+16. kernel_restore_block_q4_0 copies bytes as-is.
kernel void kernel_set_rows_q4_0_soa_i64(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2);
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global uchar * qs = q_row + blk * (QK4_0/2);
global half * d_bk = d_row + blk;
quantize_q4_0_block(x, qs, d_bk);
}
}
kernel void kernel_set_rows_q4_0_soa_i32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst_q,
ulong offset_q,
global char * dst_d,
ulong offset_d,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
uint4 ne11,
uint4 ne12,
ulong nb10,
ulong nb11,
ulong nb12,
int nblk0,
int ne1_dst,
int ne2_dst,
int ne3_dst
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst_q = dst_q + offset_q;
dst_d = dst_d + offset_d;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
if (i01 >= ne01) {
return;
}
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
long row_blk_base = ((long)i03 * ne2_dst * ne1_dst + (long)i02 * ne1_dst + i1) * nblk0;
global half * d_row = (global half *)(dst_d) + row_blk_base;
global uchar * q_row = (global uchar *)(dst_q) + row_blk_base * (QK4_0/2);
global float * src_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
for (int blk = get_local_id(0); blk < nblk0; blk += get_local_size(0)) {
global float * x = src_row + blk * QK4_0;
global uchar * qs = q_row + blk * (QK4_0/2);
global half * d_bk = d_row + blk;
quantize_q4_0_block(x, qs, d_bk);
}
}
+3 -66
View File
@@ -1270,77 +1270,14 @@ void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecode
}
std::string GgmlOvDecoder::compute_op_type(const ggml_tensor * node) {
static const std::map<ggml_op, std::string> ops = {
{GGML_OP_NONE, "GGML_OP_NONE" },
{GGML_OP_ACC, "GGML_OP_ACC" },
{GGML_OP_ADD, "GGML_OP_ADD" },
{GGML_OP_ADD1, "GGML_OP_ADD1" },
{GGML_OP_ADD_ID, "GGML_OP_ADD_ID" },
{GGML_OP_CONCAT, "GGML_OP_CONCAT" },
{GGML_OP_CONT, "GGML_OP_CONT" },
{GGML_OP_DIV, "GGML_OP_DIV" },
{GGML_OP_DUP, "GGML_OP_DUP" },
{GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS" },
{GGML_OP_MUL, "GGML_OP_MUL" },
{GGML_OP_MUL_MAT, "GGML_OP_MUL_MAT" },
{GGML_OP_MUL_MAT_ID, "GGML_OP_MUL_MAT_ID" },
{GGML_OP_PERMUTE, "GGML_OP_PERMUTE" },
{GGML_OP_RESHAPE, "GGML_OP_RESHAPE" },
{GGML_OP_RMS_NORM, "GGML_OP_RMS_NORM" },
{GGML_OP_NORM, "GGML_OP_NORM" },
{GGML_OP_ROPE, "GGML_OP_ROPE" },
{GGML_OP_SCALE, "GGML_OP_SCALE" },
{GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX" },
{GGML_OP_SUM_ROWS, "GGML_OP_SUM_ROWS" },
{GGML_OP_SUB, "GGML_OP_SUB" },
{GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE" },
{GGML_OP_VIEW, "GGML_OP_VIEW" },
{GGML_OP_SET_ROWS, "GGML_OP_SET_ROWS" },
{GGML_OP_CPY, "GGML_OP_CPY" },
{GGML_OP_FLASH_ATTN_EXT, "GGML_OP_FLASH_ATTN_EXT" },
{GGML_OP_L2_NORM, "GGML_OP_L2_NORM" },
{GGML_OP_CLAMP, "GGML_OP_CLAMP" },
{GGML_OP_PAD, "GGML_OP_PAD" },
{GGML_OP_SSM_CONV, "GGML_OP_SSM_CONV" },
{GGML_OP_GATED_DELTA_NET, "GGML_OP_GATED_DELTA_NET"},
{GGML_OP_ARGSORT, "GGML_OP_ARGSORT" },
{GGML_OP_REPEAT, "GGML_OP_REPEAT" },
{GGML_OP_IM2COL, "GGML_OP_IM2COL" }
};
static const std::map<ggml_unary_op, std::string> unary_ops = {
{GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS" },
{GGML_UNARY_OP_SGN, "GGML_UNARY_OP_SGN" },
{GGML_UNARY_OP_NEG, "GGML_UNARY_OP_NEG" },
{GGML_UNARY_OP_STEP, "GGML_UNARY_OP_STEP" },
{GGML_UNARY_OP_TANH, "GGML_UNARY_OP_TANH" },
{GGML_UNARY_OP_ELU, "GGML_UNARY_OP_ELU" },
{GGML_UNARY_OP_RELU, "GGML_UNARY_OP_RELU" },
{GGML_UNARY_OP_SIGMOID, "GGML_UNARY_OP_SIGMOID" },
{GGML_UNARY_OP_GELU, "GGML_UNARY_OP_GELU" },
{GGML_UNARY_OP_GELU_QUICK, "GGML_UNARY_OP_GELU_QUICK" },
{GGML_UNARY_OP_SILU, "GGML_UNARY_OP_SILU" },
{GGML_UNARY_OP_SOFTPLUS, "GGML_UNARY_OP_SOFTPLUS" },
{GGML_UNARY_OP_HARDSWISH, "GGML_UNARY_OP_HARDSWISH" },
{GGML_UNARY_OP_HARDSIGMOID, "GGML_UNARY_OP_HARDSIGMOID"},
{GGML_UNARY_OP_EXP, "GGML_UNARY_OP_EXP" },
{GGML_UNARY_OP_COUNT, "GGML_UNARY_OP_COUNT" }
};
static const std::map<ggml_glu_op, std::string> glu_ops = {
{GGML_GLU_OP_SWIGLU, "GGML_GLU_OP_SWIGLU"},
{GGML_GLU_OP_GEGLU, "GGML_GLU_OP_GEGLU" },
{GGML_GLU_OP_REGLU, "GGML_GLU_OP_REGLU" }
};
switch (node->op) {
case GGML_OP_UNARY:
return unary_ops.at(ggml_get_unary_op(node));
return std::string("GGML_UNARY_OP_") + ggml_unary_op_name(ggml_get_unary_op(node));
case GGML_OP_GLU:
return glu_ops.at(ggml_get_glu_op(node));
return std::string("GGML_GLU_OP_") + ggml_glu_op_name(ggml_get_glu_op(node));
default:
return ops.at(node->op);
return std::string("GGML_OP_") + ggml_op_name(node->op);
}
static const std::string unknown_op = "UNKNOWN_GGML_OP";
return unknown_op;
}
const std::string & GgmlOvDecoder::get_op_type(int node_idx) const {
+4
View File
@@ -1053,6 +1053,10 @@ static bool is_op_unsupported_case(const ggml_tensor * op) {
(op->ne[0] == 2 && op->ne[1] == 4 && op->ne[2] == 3 && op->ne[3] == 2)) {
return true;
}
// CPY into a strided view of a larger buffer (recurrent-state snapshots) not supported
if (op->view_src && ggml_nbytes(op) != ggml_nbytes(op->view_src)) {
return true;
}
break;
}
case GGML_OP_MUL_MAT: {
+19 -5
View File
@@ -17,6 +17,22 @@ namespace frontend {
namespace ggml {
namespace op {
static ov::Output<ov::Node> reshape_add_id_input_to_2d(const ov::Output<ov::Node> & input,
const ov::PartialShape & input_shape,
const std::vector<int> & dims) {
const auto actual_shape = input.get_partial_shape();
if (actual_shape.rank().is_static() && actual_shape.rank().get_length() == 2) {
return input;
}
if (input_shape.rank().is_static() && input_shape.rank().get_length() == 2) {
return input;
}
auto shape = std::make_shared<ov::op::v3::ShapeOf>(input, ov::element::i64);
return std::make_shared<ov::op::v1::Reshape>(input, get_dimensions(shape, dims), false);
}
OutputVector translate_add_id(const NodeContext & context) {
num_inputs_check(context, 3, 3);
@@ -28,11 +44,9 @@ OutputVector translate_add_id(const NodeContext & context) {
// input: [1, n_token, n_used, n_embd]
// bias: [1, 1, n_expert, n_embd]
// ids: [1, 1, n_token, n_used]
auto bias_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(bias, ov::element::i64);
auto ids_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(ids, ov::element::i64);
bias = std::make_shared<ov::op::v1::Reshape>(bias, get_dimensions(bias_shape_4d, {2, 3}), false);
ids = std::make_shared<ov::op::v1::Reshape>(ids, get_dimensions(ids_shape_4d, {2, 3}), false);
// Model bias constants may already be stored as [n_expert, n_embd].
bias = reshape_add_id_input_to_2d(bias, context.get_input_shape(1), {2, 3});
ids = reshape_add_id_input_to_2d(ids, context.get_input_shape(2), {2, 3});
if (ids.get_element_type() != ov::element::i32 && ids.get_element_type() != ov::element::i64) {
ids = std::make_shared<ov::op::v0::Convert>(ids, ov::element::i32);
@@ -3,8 +3,11 @@
#include "../utils.h"
#include <cstdint>
#include <limits>
#include <memory>
#include <openvino/core/node_output.hpp>
#include <openvino/op/add.hpp>
#include <openvino/op/clamp.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/sigmoid.hpp>
@@ -15,7 +18,7 @@ namespace frontend {
namespace ggml {
namespace op {
OutputVector translate_glu_swiglu(const NodeContext & context) {
static std::pair<ov::Output<ov::Node>, ov::Output<ov::Node>> get_glu_inputs(const NodeContext & context) {
num_inputs_check(context, 1, 2);
ov::Output<ov::Node> src0;
@@ -52,6 +55,12 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
std::swap(src0, src1);
}
return {src0, src1};
}
OutputVector translate_glu_swiglu(const NodeContext & context) {
auto [src0, src1] = get_glu_inputs(context);
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src0);
auto silu = std::make_shared<ov::op::v1::Multiply>(src0, sigmoid);
auto res = std::make_shared<ov::op::v1::Multiply>(silu, src1);
@@ -59,6 +68,27 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
return rename_outputs_with_suffix({res}, context.get_name());
}
OutputVector translate_glu_swiglu_oai(const NodeContext & context) {
auto [src0, src1] = get_glu_inputs(context);
const int32_t * params = context.get_output_op_params();
const float alpha = reinterpret_cast<const float *>(params)[2];
const float limit = reinterpret_cast<const float *>(params)[3];
auto gate = std::make_shared<ov::op::v0::Clamp>(src0, -std::numeric_limits<float>::infinity(), limit);
auto alpha_const = ov::op::v0::Constant::create(ov::element::f32, {}, {alpha});
auto scaled_gate = std::make_shared<ov::op::v1::Multiply>(gate, alpha_const);
auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(scaled_gate);
auto out_glu = std::make_shared<ov::op::v1::Multiply>(gate, sigmoid);
auto up = std::make_shared<ov::op::v0::Clamp>(src1, -limit, limit);
auto one = ov::op::v0::Constant::create(ov::element::f32, {}, {1.0f});
auto up_plus_one = std::make_shared<ov::op::v1::Add>(up, one);
auto res = std::make_shared<ov::op::v1::Multiply>(out_glu, up_plus_one);
return rename_outputs_with_suffix({res}, context.get_name());
}
} // namespace op
} // namespace ggml
} // namespace frontend
@@ -2,23 +2,135 @@
#include "../op_table.h"
#include "../utils.h"
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <openvino/op/bitwise_and.hpp>
#include <openvino/op/bitwise_right_shift.hpp>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/shape_of.hpp>
#include <openvino/op/squeeze.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <vector>
namespace ov {
namespace frontend {
namespace ggml {
namespace op {
namespace {
std::shared_ptr<ov::op::v0::Constant> const_i64(const std::vector<int64_t> & values) {
return ov::op::v0::Constant::create(ov::element::i64, ov::Shape{values.size()}, values);
}
ov::Output<ov::Node> slice_axis(const ov::Output<ov::Node> & input, int64_t axis, int64_t begin, int64_t end) {
return std::make_shared<ov::op::v8::Slice>(input, const_i64({begin}), const_i64({end}), const_i64({1}),
const_i64({axis}));
}
ov::Output<ov::Node> translate_mul_mat_id_mxfp4_packed(const NodeContext & context,
ov::Output<ov::Node> expert_weights,
ov::Output<ov::Node> activations,
ov::Output<ov::Node> ids) {
auto packed_shape = expert_weights.get_partial_shape().to_shape();
FRONT_END_OP_CONVERSION_CHECK(packed_shape.size() == 5 && packed_shape[4] == 17,
"Expected packed MXFP4 expert weights with shape [1, n_expert, m, k_blocks, 17]");
const int64_t n_expert = static_cast<int64_t>(packed_shape[1]);
const int64_t rows = static_cast<int64_t>(packed_shape[2]);
const int64_t k_blocks = static_cast<int64_t>(packed_shape[3]);
const int64_t qk = 32;
const int64_t cols = k_blocks * qk;
auto packed_shape_4d = const_i64({n_expert, rows, k_blocks, 17});
expert_weights = std::make_shared<ov::op::v1::Reshape>(expert_weights, packed_shape_4d, false);
auto activations_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(activations, ov::element::i64);
auto ids_shape_4d = std::make_shared<ov::op::v3::ShapeOf>(ids, ov::element::i64);
auto activations_shape_3d = get_dimensions(activations_shape_4d, {1, 2, 3});
auto ids_shape_2d = get_dimensions(ids_shape_4d, {2, 3});
activations = std::make_shared<ov::op::v1::Reshape>(activations, activations_shape_3d, false);
ids = std::make_shared<ov::op::v1::Reshape>(ids, ids_shape_2d, false);
if (ids.get_element_type() != ov::element::i32 && ids.get_element_type() != ov::element::i64) {
ids = std::make_shared<ov::op::v0::Convert>(ids, ov::element::i32);
}
auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
static const std::vector<float> f4e2m1_lut = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
-0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
std::vector<float> e8m0_lut(256);
for (size_t i = 0; i < e8m0_lut.size(); ++i) {
uint32_t bits = static_cast<uint32_t>(i) << 23;
memcpy(&e8m0_lut[i], &bits, sizeof(float));
}
e8m0_lut[0] = std::numeric_limits<float>::min() / 2.0f;
e8m0_lut[255] = std::numeric_limits<float>::quiet_NaN();
auto f4_lut = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{f4e2m1_lut.size()}, f4e2m1_lut);
auto scale_lut = ov::op::v0::Constant::create(ov::element::f32, ov::Shape{e8m0_lut.size()}, e8m0_lut);
auto selected_packed_weights = std::make_shared<ov::op::v8::Gather>(expert_weights, ids, gather_axis);
auto scale_byte = slice_axis(selected_packed_weights, 4, 0, 1);
auto qs = slice_axis(selected_packed_weights, 4, 1, 17);
auto low = std::make_shared<ov::op::v13::BitwiseAnd>(
qs, ov::op::v0::Constant::create(ov::element::u8, ov::Shape{}, {0x0F}), ov::op::AutoBroadcastType::NUMPY);
auto high_shift = std::make_shared<ov::op::v15::BitwiseRightShift>(
qs, ov::op::v0::Constant::create(ov::element::u8, ov::Shape{}, {4}), ov::op::AutoBroadcastType::NUMPY);
auto nibbles = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{low, high_shift}, 4);
auto nibble_indices = std::make_shared<ov::op::v0::Convert>(nibbles, ov::element::i32);
auto weights_f32 = std::make_shared<ov::op::v8::Gather>(f4_lut, nibble_indices, gather_axis);
auto scale_indices = std::make_shared<ov::op::v0::Convert>(scale_byte, ov::element::i32);
auto scales_f32 = std::make_shared<ov::op::v8::Gather>(scale_lut, scale_indices, gather_axis);
ov::Output<ov::Node> selected_weights = std::make_shared<ov::op::v1::Multiply>(weights_f32, scales_f32,
ov::op::AutoBroadcastType::NUMPY);
auto ids_shape = std::make_shared<ov::op::v3::ShapeOf>(ids, ov::element::i64);
auto selected_weights_target_dims = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{get_dimensions(ids_shape, {0, 1}), const_i64({rows, cols})}, 0);
selected_weights = std::make_shared<ov::op::v1::Reshape>(selected_weights, selected_weights_target_dims, false);
auto activations_shape = std::make_shared<ov::op::v3::ShapeOf>(activations, ov::element::i64);
ov::Output<ov::Node> acts_target_dims = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{
get_dimensions(activations_shape, {0}),
get_dimensions(ids_shape, {1}),
get_dimensions(activations_shape, {2}),
},
0);
ov::Output<ov::Node> acts_broadcasted =
std::make_shared<ov::op::v3::Broadcast>(activations, acts_target_dims, ov::op::BroadcastType::BIDIRECTIONAL);
auto activations_expanded = std::make_shared<ov::op::v0::Unsqueeze>(acts_broadcasted, const_i64({2}));
ov::Output<ov::Node> result =
std::make_shared<ov::op::v0::MatMul>(activations_expanded, selected_weights, false, true);
auto batch_dim = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto row_dim = ov::op::v0::Constant::create(ov::element::i64, {1}, {rows});
auto result_target_dims = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{batch_dim, get_dimensions(ids_shape, {0, 1}), row_dim}, 0);
result = std::make_shared<ov::op::v1::Reshape>(result, result_target_dims, false);
const auto output_type = context.get_output_type();
if (result.get_element_type() != output_type) {
result = std::make_shared<ov::op::v0::Convert>(result, output_type);
}
return result;
}
} // namespace
OutputVector translate_mul_mat_id(const NodeContext & context) {
num_inputs_check(context, 3, 3);
@@ -26,6 +138,12 @@ OutputVector translate_mul_mat_id(const NodeContext & context) {
auto activations = process_view_input_new(context, 1);
auto ids = process_view_input_new(context, 2);
if (expert_weights.get_element_type() == ov::element::u8 && expert_weights.get_partial_shape().rank().is_static() &&
expert_weights.get_partial_shape().rank().get_length() == 5) {
return rename_outputs_with_suffix({translate_mul_mat_id_mxfp4_packed(context, expert_weights, activations, ids)},
context.get_name());
}
// OpenVINO sees GGML tensors in reversed dimension order:
// weights: [1, n_expert, m, k]
// activations: [1, n_tokens, n_used_or_1, k]

Some files were not shown because too many files have changed in this diff Show More