mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-28 18:23:03 +02:00
Compare commits
77 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 27c8bb4f63 | |||
| ebd048fc5e | |||
| 0ed235ea2c | |||
| 9bebfcb4bc | |||
| 0b6529d818 | |||
| c299a92c38 | |||
| 0275c0f800 | |||
| 83d385b429 | |||
| 050ee92d04 | |||
| 3fc4e10527 | |||
| 5d8ccdf9d1 | |||
| 024930c6ad | |||
| 5397c36194 | |||
| e7ea94afcb | |||
| 96183e9820 | |||
| 487a6cc164 | |||
| 5a6a0dd7e1 | |||
| ded1561b42 | |||
| 9df06805ee | |||
| 2f18fe13c5 | |||
| c16c35b814 | |||
| 1a87dcdc45 | |||
| e7e3f35090 | |||
| b11f7c16bc | |||
| f818065d75 | |||
| 960d628f46 | |||
| 5c7c22c3e1 | |||
| beac5309f1 | |||
| 9d5d882d8c | |||
| 1ec44d178d | |||
| c7cddefcbd | |||
| e9d1b76d0a | |||
| 099bf06952 | |||
| 60bc8866b1 | |||
| e8ecce53b8 | |||
| 683b04cc4a | |||
| f728adab68 | |||
| 3e61ea0e2f | |||
| fdbd6abee2 | |||
| e12a0128ab | |||
| b3ce5cedf4 | |||
| e9fb3b3fc0 | |||
| 9c10954865 | |||
| fdb2c11c70 | |||
| 09cedfd699 | |||
| 8be759e6f7 | |||
| 894bb27af3 | |||
| fb401045cc | |||
| 51eae8cfca | |||
| 1191758c5d | |||
| 00139b660b | |||
| ef9c13d4c2 | |||
| 88636e178f | |||
| ac4105d68b | |||
| be4a6a63eb | |||
| 72a9269172 | |||
| 92e854ab83 | |||
| c5606364b2 | |||
| 0eb874d374 | |||
| 75ad0b23ed | |||
| c926ad0985 | |||
| a3900a6694 | |||
| 7c908502ea | |||
| 035cd8f9a6 | |||
| 73618f27a8 | |||
| 23ee8797e1 | |||
| dec5ca5577 | |||
| 9c0ac887f3 | |||
| 721354fbdf | |||
| 6ee0f65793 | |||
| 099b579acb | |||
| f8cc15f163 | |||
| 37957e8531 | |||
| d0f9d2e5ac | |||
| 0ef6f06d55 | |||
| 52b3df0023 | |||
| 7c082bc417 |
@@ -145,7 +145,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
# ==============================================================================
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
ENTRYPOINT [ "/app/llama-cli" ]
|
||||
|
||||
@@ -156,7 +156,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
HEALTHCHECK --interval=5m CMD [ "curl", "-f", "http://localhost:8080/health" ]
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -115,7 +115,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -124,7 +124,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/lib/ /app
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -153,7 +153,7 @@ FROM base AS server
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/lib/ /app
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -115,7 +115,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -126,7 +126,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
ARG OPENVINO_VERSION_MAJOR=2026.2
|
||||
ARG OPENVINO_VERSION_FULL=2026.2.0.21903.52ddc073857
|
||||
ARG OPENVINO_VERSION_MAJOR=2026.2.1
|
||||
ARG OPENVINO_VERSION_FULL=2026.2.1.21919.ede283a88e3
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
# Intel GPU driver versions. https://github.com/intel/compute-runtime/releases
|
||||
ARG IGC_VERSION=v2.34.4
|
||||
ARG IGC_VERSION_FULL=2_2.34.4+21428
|
||||
ARG COMPUTE_RUNTIME_VERSION=26.18.38308.1
|
||||
ARG COMPUTE_RUNTIME_VERSION_FULL=26.18.38308.1-0
|
||||
ARG IGC_VERSION=v2.36.3
|
||||
ARG IGC_VERSION_FULL=2_2.36.3+21719
|
||||
ARG COMPUTE_RUNTIME_VERSION=26.22.38646.4
|
||||
ARG COMPUTE_RUNTIME_VERSION_FULL=26.22.38646.4-0
|
||||
ARG IGDGMM_VERSION=22.10.0
|
||||
|
||||
# Intel NPU driver versions. https://github.com/intel/linux-npu-driver/releases
|
||||
@@ -214,7 +214,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app/
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -225,7 +225,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app/
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -127,7 +127,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -138,7 +138,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ WORKDIR /llama.cpp/bin
|
||||
|
||||
# Copy llama.cpp binaries and libraries
|
||||
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama-cli /llama.cpp/bin/llama-completion /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama /llama.cpp/bin/llama-cli /llama.cpp/bin/llama-completion /llama.cpp/bin
|
||||
|
||||
ENTRYPOINT [ "/llama.cpp/bin/llama-cli" ]
|
||||
|
||||
@@ -138,7 +138,7 @@ WORKDIR /llama.cpp/bin
|
||||
|
||||
# Copy llama.cpp binaries and libraries
|
||||
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama-server /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama /llama.cpp/bin/llama-server /llama.cpp/bin
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -118,7 +118,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -108,7 +108,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
+26
-19
@@ -35,8 +35,20 @@ AMD ZenDNN:
|
||||
documentation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "**/*.md"
|
||||
- docs/**
|
||||
- media/**
|
||||
examples:
|
||||
- all:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- app/**
|
||||
- examples/**
|
||||
- tools/**
|
||||
- all-globs-to-all-files:
|
||||
- '!tools/server/**'
|
||||
- '!tools/mtmd/**'
|
||||
- '!tools/ui/**'
|
||||
testing:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
@@ -47,28 +59,12 @@ build:
|
||||
- cmake/**
|
||||
- CMakeLists.txt
|
||||
- CMakePresets.json
|
||||
examples:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- examples/**
|
||||
- tools/**
|
||||
devops:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- .devops/**
|
||||
- .github/**
|
||||
- ci/**
|
||||
python:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- "**/*.py"
|
||||
- requirements/**
|
||||
- gguf-py/**
|
||||
- .flake8
|
||||
script:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- scripts/**
|
||||
android:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
@@ -81,9 +77,20 @@ server:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- tools/server/**
|
||||
|
||||
|
||||
|
||||
mtmd:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- tools/mtmd/**
|
||||
conversion:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- conversion/**
|
||||
- convert_*.py
|
||||
- gguf-py/**
|
||||
vendor:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- vendor/**
|
||||
ggml:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
|
||||
@@ -68,8 +68,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -96,8 +96,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
@@ -39,8 +39,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -96,8 +96,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
@@ -266,8 +266,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
@@ -446,8 +446,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Set OpenVINO version output
|
||||
@@ -506,8 +506,11 @@ jobs:
|
||||
cmake -B build/ReleaseOV -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_OPENVINO=ON \
|
||||
-DHF_UI_VERSION=${{ needs.get-version.outputs.ui_version }}
|
||||
cmake --build build/ReleaseOV --config Release -j $(nproc)
|
||||
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
|
||||
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
|
||||
-DHF_UI_VERSION=${{ needs.get-version.outputs.ui_version }} \
|
||||
${{ env.CMAKE_ARGS }}
|
||||
cmake --build build/ReleaseOV --config Release --parallel
|
||||
|
||||
- name: ccache-clear
|
||||
uses: ./.github/actions/ccache-clear
|
||||
@@ -521,8 +524,26 @@ jobs:
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/ReleaseOV/bin/
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz --transform "s,^\.,llama-${{ steps.tag.outputs.name }}," -C ./build/ReleaseOV/bin .
|
||||
dest=./build/ReleaseOV/bin
|
||||
OPENVINO_ROOT=./openvino_toolkit
|
||||
ov_lib="$OPENVINO_ROOT/runtime/lib/intel64"
|
||||
|
||||
# Bundle OpenVINO runtime libs + TBB. Binaries built with RPATH=$ORIGIN
|
||||
# load these siblings without setupvars.sh / LD_LIBRARY_PATH.
|
||||
cp -P "$ov_lib"/libopenvino.so* \
|
||||
"$ov_lib"/libopenvino_c.so* \
|
||||
"$ov_lib"/libopenvino_*_plugin.so \
|
||||
"$ov_lib"/libopenvino_intel_npu_compiler*.so \
|
||||
"$OPENVINO_ROOT"/runtime/3rdparty/tbb/lib/*.so* \
|
||||
"$dest"
|
||||
cp -P /usr/lib/x86_64-linux-gnu/libOpenCL.so.1* "$dest" 2>/dev/null || true
|
||||
cp "$ov_lib"/cache.json "$dest" 2>/dev/null || true
|
||||
|
||||
# OpenVINO licensing
|
||||
cp -r "$OPENVINO_ROOT"/docs/licensing "$dest"/openvino-licensing
|
||||
|
||||
cp LICENSE "$dest"
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz --transform "s,^\.,llama-${{ steps.tag.outputs.name }}," -C "$dest" .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
@@ -531,6 +552,9 @@ jobs:
|
||||
name: llama-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz
|
||||
|
||||
windows-openvino:
|
||||
needs: [check-release]
|
||||
if: ${{ needs.check-release.outputs.should_release == 'true' }}
|
||||
|
||||
runs-on: windows-2022
|
||||
|
||||
outputs:
|
||||
@@ -538,8 +562,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Set OpenVINO version output
|
||||
@@ -607,7 +631,9 @@ jobs:
|
||||
-A x64 ^
|
||||
-DCMAKE_BUILD_TYPE=Release ^
|
||||
-DGGML_OPENVINO=ON ^
|
||||
-DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake
|
||||
-DLLAMA_BUILD_BORINGSSL=ON ^
|
||||
-DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake ^
|
||||
${{ env.CMAKE_ARGS }}
|
||||
|
||||
cmake --build build\ReleaseOV --config Release -- /m
|
||||
|
||||
@@ -624,8 +650,29 @@ jobs:
|
||||
id: pack_artifacts
|
||||
shell: powershell
|
||||
run: |
|
||||
Copy-Item LICENSE .\build\ReleaseOV\bin\
|
||||
7z a -snl llama-${{ steps.tag.outputs.name }}-bin-win-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.zip .\build\ReleaseOV\bin\*
|
||||
# Locate the extracted OpenVINO toolkit root (same pattern as the Build step).
|
||||
$OPENVINO_ROOT = (Get-ChildItem -Directory openvino_toolkit | Select-Object -First 1).FullName
|
||||
if (-not $OPENVINO_ROOT) {
|
||||
Write-Error "OpenVINO toolkit folder not found under .\openvino_toolkit"
|
||||
exit 1
|
||||
}
|
||||
|
||||
$dest = ".\build\ReleaseOV\bin\Release"
|
||||
|
||||
$ovBin = Join-Path $OPENVINO_ROOT 'runtime\bin\intel64\Release'
|
||||
Copy-Item -Path (Join-Path $ovBin '*.dll') -Destination $dest -Force
|
||||
Copy-Item -Path (Join-Path $ovBin 'cache.json') -Destination $dest -Force
|
||||
|
||||
$tbbBin = Join-Path $OPENVINO_ROOT 'runtime\3rdparty\tbb\bin'
|
||||
Copy-Item -Path (Join-Path $tbbBin 'tbb*.dll') -Destination $dest -Force
|
||||
|
||||
# OpenVINO licensing
|
||||
$licensingDest = Join-Path $dest 'openvino-licensing'
|
||||
New-Item -ItemType Directory -Force -Path $licensingDest | Out-Null
|
||||
Copy-Item -Path (Join-Path $OPENVINO_ROOT 'docs\licensing\*') -Destination $licensingDest -Recurse -Force
|
||||
|
||||
Copy-Item LICENSE $dest
|
||||
7z a -snl llama-${{ steps.tag.outputs.name }}-bin-win-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.zip $dest\*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
|
||||
@@ -222,6 +222,16 @@ if (LLAMA_BUILD_APP)
|
||||
add_subdirectory(app)
|
||||
endif()
|
||||
|
||||
# Standalone libmtmd build without pulling in the rest of the tools/ tree.
|
||||
# Useful when packaging just the mtmd library for language bindings (e.g. an
|
||||
# Apple XCFramework, or a WASM build). When the full tools build is enabled,
|
||||
# mtmd is already built by the tools/ subdirectory above; this hook only fires
|
||||
# when LLAMA_BUILD_TOOLS is OFF to avoid double-adding the target.
|
||||
option(LLAMA_BUILD_MTMD "llama: build tools/mtmd library standalone" OFF)
|
||||
if (LLAMA_BUILD_MTMD AND NOT (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TOOLS))
|
||||
add_subdirectory(tools/mtmd)
|
||||
endif()
|
||||
|
||||
#
|
||||
# install
|
||||
#
|
||||
|
||||
+1
-1
@@ -10,7 +10,7 @@
|
||||
# ggml-org/ggml-rpc : rgerganov
|
||||
# ggml-org/ggml-sycl : arthw
|
||||
# ggml-org/ggml-vulkan : 0cc4m, jeffbolznv
|
||||
# ggml-org/ggml-webgpu : reeselevine
|
||||
# ggml-org/ggml-webgpu : reeselevine, yomaytk
|
||||
# ggml-org/ggml-zdnn : taronaeo
|
||||
# ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin
|
||||
# ggml-org/llama-mtmd : ngxson
|
||||
|
||||
@@ -142,7 +142,9 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
|
||||
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
|
||||
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
|
||||
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
|
||||
- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2)
|
||||
- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25)
|
||||
- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos)
|
||||
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
|
||||
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
|
||||
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
|
||||
|
||||
+1
-1
@@ -80,7 +80,7 @@ To protect sensitive data from potential leaks or unauthorized access, it is cru
|
||||
### Untrusted environments or networks
|
||||
|
||||
If you can't run your models in a secure and isolated environment, or if they must be exposed to an untrusted network, make sure to take the following security precautions:
|
||||
* Do not use the RPC backend, [rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
|
||||
* Do not use the RPC backend, [ggml-rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
|
||||
* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value.
|
||||
* Encrypt your data if sending it over the network.
|
||||
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
set(TARGET llama-app)
|
||||
|
||||
add_executable(${TARGET} llama.cpp)
|
||||
add_executable(${TARGET} llama.cpp download.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <filesystem>
|
||||
|
||||
// Prints a short list of example invocations to stdout.
//
// argv[0] is substituted as the program name in every example line; the
// argument count is not used.
static void print_usage(int /*argc*/, char ** argv) {
    const char * prog = argv[0];

    printf("\nexamples:\n");
    printf(" %s -hf ggml-org/gemma-3-4b-it-qat-GGUF\n", prog);
    printf(" %s -hf ggml-org/gemma-3-4b-it-qat-GGUF:Q4_K_M\n", prog);
    printf(" %s -hf ggml-org/models -hff model.gguf\n", prog);
    printf(" %s -mu https://example.com/model.gguf -m model.gguf\n", prog);
    printf("\n");
}
|
||||
|
||||
int llama_download(int argc, char ** argv);
|
||||
|
||||
int llama_download(int argc, char ** argv) {
|
||||
common_init();
|
||||
|
||||
common_params params;
|
||||
params.verbosity = LOG_LEVEL_ERROR;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DOWNLOAD, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const bool has_source = !params.model.hf_repo.empty() || !params.model.url.empty() ||
|
||||
!params.model.path.empty() || !params.model.docker_repo.empty();
|
||||
if (!has_source) {
|
||||
fprintf(stderr, "error: no model source specified (use --hf-repo, --model-url, --model or --docker-repo)\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
try {
|
||||
common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_DOWNLOAD);
|
||||
common_models_handler_apply(handler, params);
|
||||
} catch (const std::exception & e) {
|
||||
fprintf(stderr, "error: %s\n", e.what());
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!params.models_preset.empty()) {
|
||||
// -hf pointed at a preset repo: print the preset path and stop
|
||||
printf("%s\n", params.models_preset.c_str());
|
||||
return 0;
|
||||
}
|
||||
if (params.model.path.empty()) {
|
||||
fprintf(stderr, "error: model download failed\n");
|
||||
return 1;
|
||||
}
|
||||
if (!std::filesystem::exists(params.model.path)) {
|
||||
fprintf(stderr, "error: model file does not exist: %s\n", params.model.path.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
printf("%s\n", params.model.path.c_str());
|
||||
if (!params.mmproj.path.empty()) {
|
||||
printf("%s\n", params.mmproj.path.c_str());
|
||||
}
|
||||
if (!params.speculative.draft.mparams.path.empty()) {
|
||||
printf("%s\n", params.speculative.draft.mparams.path.c_str());
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
+10
-4
@@ -19,6 +19,7 @@ int llama_batched_bench(int argc, char ** argv);
|
||||
int llama_fit_params(int argc, char ** argv);
|
||||
int llama_quantize(int argc, char ** argv);
|
||||
int llama_perplexity(int argc, char ** argv);
|
||||
int llama_download(int argc, char ** argv);
|
||||
|
||||
// Self-update is only supported for binaries built with llama-install.sh
|
||||
static int llama_update(int argc, char ** argv) {
|
||||
@@ -49,6 +50,7 @@ struct command {
|
||||
std::vector<std::string> aliases;
|
||||
bool hidden;
|
||||
int (*func)(int, char **);
|
||||
bool flags = false; // allow --name
|
||||
};
|
||||
|
||||
#ifdef LLAMA_INSTALL_BUILD
|
||||
@@ -61,15 +63,16 @@ static const command cmds[] = {
|
||||
{"serve", "HTTP API server", {"server"}, false, llama_server },
|
||||
{"cli", "Command-line interactive interface", {"client"}, false, llama_cli },
|
||||
{"update", "Update llama to the latest release", {}, UPDATE_HIDDEN, llama_update },
|
||||
{"download", "Download a model", {"get"}, false, llama_download },
|
||||
{"completion", "Text completion", {"complete"}, true, llama_completion },
|
||||
{"bench", "Benchmark prompt processing and text generation", {}, true, llama_bench },
|
||||
{"batched-bench", "Benchmark batched decoding performance", {}, true, llama_batched_bench},
|
||||
{"fit-params", "Compute parameters to fit a model in device memory", {}, true, llama_fit_params },
|
||||
{"quantize", "Quantize a model", {}, true, llama_quantize },
|
||||
{"perplexity", "Compute model perplexity and KL divergence", {}, true, llama_perplexity },
|
||||
{"version", "Show version", {}, false, version },
|
||||
{"licenses", "Show third-party licenses", {"credits"}, false, licenses },
|
||||
{"help", "Show available commands", {}, false, help },
|
||||
{"version", "Show version", {}, false, version, true },
|
||||
{"licenses", "Show third-party licenses", {"credits"}, false, licenses, true },
|
||||
{"help", "Show available commands", {}, false, help, true },
|
||||
};
|
||||
|
||||
#undef UPDATE_HIDDEN
|
||||
@@ -106,7 +109,10 @@ static int help(int argc, char ** argv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static bool matches(const std::string & arg, const command & cmd) {
|
||||
static bool matches(std::string arg, const command & cmd) {
|
||||
if (cmd.flags && arg.size() > 2 && arg[0] == '-' && arg[1] == '-') {
|
||||
arg.erase(0, 2);
|
||||
}
|
||||
if (arg == cmd.name) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ LLAMA_BUILD_EXAMPLES=OFF
|
||||
LLAMA_BUILD_TOOLS=OFF
|
||||
LLAMA_BUILD_TESTS=OFF
|
||||
LLAMA_BUILD_SERVER=OFF
|
||||
LLAMA_BUILD_MTMD=ON
|
||||
GGML_METAL=ON
|
||||
GGML_METAL_EMBED_LIBRARY=ON
|
||||
GGML_BLAS_DEFAULT=ON
|
||||
@@ -39,6 +40,7 @@ COMMON_CMAKE_ARGS=(
|
||||
-DLLAMA_BUILD_TOOLS=${LLAMA_BUILD_TOOLS}
|
||||
-DLLAMA_BUILD_TESTS=${LLAMA_BUILD_TESTS}
|
||||
-DLLAMA_BUILD_SERVER=${LLAMA_BUILD_SERVER}
|
||||
-DLLAMA_BUILD_MTMD=${LLAMA_BUILD_MTMD}
|
||||
-DGGML_METAL_EMBED_LIBRARY=${GGML_METAL_EMBED_LIBRARY}
|
||||
-DGGML_BLAS_DEFAULT=${GGML_BLAS_DEFAULT}
|
||||
-DGGML_METAL=${GGML_METAL}
|
||||
@@ -126,6 +128,8 @@ setup_framework_structure() {
|
||||
cp ggml/include/ggml-cpu.h ${header_path}
|
||||
cp ggml/include/ggml-blas.h ${header_path}
|
||||
cp ggml/include/gguf.h ${header_path}
|
||||
cp tools/mtmd/mtmd.h ${header_path}
|
||||
cp tools/mtmd/mtmd-helper.h ${header_path}
|
||||
|
||||
# Create module map (common for all platforms)
|
||||
cat > ${module_path}module.modulemap << EOF
|
||||
@@ -247,6 +251,7 @@ combine_static_libraries() {
|
||||
"${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-cpu.a"
|
||||
"${base_dir}/${build_dir}/ggml/src/ggml-metal/${release_dir}/libggml-metal.a"
|
||||
"${base_dir}/${build_dir}/ggml/src/ggml-blas/${release_dir}/libggml-blas.a"
|
||||
"${base_dir}/${build_dir}/tools/mtmd/${release_dir}/libmtmd.a"
|
||||
)
|
||||
|
||||
# Create temporary directory for processing
|
||||
@@ -410,6 +415,7 @@ cmake -B build-ios-sim -G Xcode \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-ios-sim --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
@@ -424,6 +430,7 @@ cmake -B build-ios-device -G Xcode \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-ios-device --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
@@ -450,6 +457,7 @@ cmake -B build-visionos -G Xcode \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
@@ -465,6 +473,7 @@ cmake -B build-visionos-sim -G Xcode \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DLLAMA_BUILD_SERVER=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-visionos-sim --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
@@ -481,6 +490,7 @@ cmake -B build-tvos-sim -G Xcode \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-tvos-sim --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
@@ -496,6 +506,7 @@ cmake -B build-tvos-device -G Xcode \
|
||||
-DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \
|
||||
-DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \
|
||||
-DLLAMA_OPENSSL=OFF \
|
||||
-DMTMD_VIDEO=OFF \
|
||||
-S .
|
||||
cmake --build build-tvos-device --config Release -j $(sysctl -n hw.logicalcpu) -- -quiet
|
||||
|
||||
|
||||
@@ -80,8 +80,6 @@ add_library(${TARGET}
|
||||
http.h
|
||||
imatrix-loader.cpp
|
||||
imatrix-loader.h
|
||||
json-partial.cpp
|
||||
json-partial.h
|
||||
json-schema-to-grammar.cpp
|
||||
llguidance.cpp
|
||||
log.cpp
|
||||
|
||||
+256
-120
@@ -297,58 +297,6 @@ struct handle_model_result {
|
||||
std::string preset_path;
|
||||
};
|
||||
|
||||
static handle_model_result common_params_handle_model(struct common_params_model & model,
|
||||
const common_download_opts & opts) {
|
||||
handle_model_result result;
|
||||
|
||||
if (!model.docker_repo.empty()) {
|
||||
model.path = common_docker_resolve_model(model.docker_repo);
|
||||
} else if (!model.hf_repo.empty()) {
|
||||
// If -m was used with -hf, treat the model "path" as the hf_file to download
|
||||
if (model.hf_file.empty() && !model.path.empty()) {
|
||||
model.hf_file = model.path;
|
||||
model.path = "";
|
||||
}
|
||||
common_download_opts hf_opts = opts;
|
||||
auto download_result = common_download_model(model, hf_opts);
|
||||
|
||||
if (!download_result.preset_path.empty()) {
|
||||
result.found_preset = true;
|
||||
result.preset_path = download_result.preset_path;
|
||||
return result; // skip everything else if preset.ini is used
|
||||
}
|
||||
|
||||
if (download_result.model_path.empty()) {
|
||||
throw std::runtime_error("failed to download model from Hugging Face");
|
||||
}
|
||||
|
||||
model.path = download_result.model_path;
|
||||
|
||||
if (!download_result.mmproj_path.empty()) {
|
||||
result.found_mmproj = true;
|
||||
result.mmproj.path = download_result.mmproj_path;
|
||||
}
|
||||
|
||||
if (!download_result.mtp_path.empty()) {
|
||||
result.found_mtp = true;
|
||||
result.mtp.path = download_result.mtp_path;
|
||||
}
|
||||
} else if (!model.url.empty()) {
|
||||
if (model.path.empty()) {
|
||||
auto f = string_split<std::string>(model.url, '#').front();
|
||||
f = string_split<std::string>(f, '?').front();
|
||||
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
|
||||
}
|
||||
|
||||
auto download_result = common_download_model(model, opts);
|
||||
if (download_result.model_path.empty()) {
|
||||
throw std::runtime_error("failed to download model from " + model.url);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::vector<ggml_type> kv_cache_types = {
|
||||
GGML_TYPE_F32,
|
||||
GGML_TYPE_F16,
|
||||
@@ -393,72 +341,241 @@ static bool parse_bool_value(const std::string & value) {
|
||||
}
|
||||
|
||||
//
|
||||
// CLI argument parsing functions
|
||||
// common_models_handler
|
||||
//
|
||||
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex) {
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
static std::string get_default_local_path(const std::string & url) {
|
||||
auto f = string_split<std::string>(url, '#').front();
|
||||
f = string_split<std::string>(f, '?').front();
|
||||
return fs_get_cache_file(string_split<std::string>(f, '/').back());
|
||||
}
|
||||
|
||||
common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex) {
|
||||
common_download_hf_plan plan;
|
||||
common_download_hf_plan plan_spec;
|
||||
common_download_hf_plan plan_voc;
|
||||
common_download_opts opts;
|
||||
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
params.speculative.types.end(),
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end();
|
||||
|
||||
// only download mmproj if the current example is using it
|
||||
bool use_mmproj = false;
|
||||
for (const auto & ex : mmproj_examples) {
|
||||
if (curr_ex == ex) {
|
||||
use_mmproj = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
opts.bearer_token = params.hf_token;
|
||||
opts.offline = params.offline;
|
||||
opts.skip_download = params.skip_download;
|
||||
opts.download_mtp = spec_type_draft_mtp;
|
||||
opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
opts.download_mmproj = use_mmproj && !params.no_mmproj
|
||||
&& params.mmproj.path.empty() && params.mmproj.url.empty();
|
||||
|
||||
// sub-models (draft, mmproj, vocoder) are explicitly specified by the user,
|
||||
// so we should not auto-discover mtp/mmproj siblings for them
|
||||
common_download_opts sub_opts = opts;
|
||||
sub_opts.download_mtp = false;
|
||||
sub_opts.download_mmproj = false;
|
||||
if (!params.model.hf_repo.empty()) {
|
||||
plan = common_download_get_hf_plan(params.model, opts);
|
||||
}
|
||||
|
||||
try {
|
||||
auto res = common_params_handle_model(params.model, opts);
|
||||
if (res.found_preset) {
|
||||
if (!params.models_preset.empty()) {
|
||||
throw std::invalid_argument("cannot use both --models-preset and -hf with a preset.ini file");
|
||||
if (!params.speculative.draft.mparams.hf_repo.empty()) {
|
||||
plan_spec = common_download_get_hf_plan(params.speculative.draft.mparams, opts);
|
||||
}
|
||||
|
||||
if (!params.vocoder.model.hf_repo.empty()) {
|
||||
plan_voc = common_download_get_hf_plan(params.vocoder.model, opts);
|
||||
}
|
||||
|
||||
return common_models_handler{plan, plan_spec, plan_voc, opts};
|
||||
}
|
||||
|
||||
bool common_models_handler_is_preset_repo(const common_models_handler & handler) {
|
||||
return !handler.plan.preset.url.empty();
|
||||
}
|
||||
|
||||
static std::vector<common_download_task> build_url_tasks(const common_params_model & model, common_download_opts opts) {
|
||||
auto parts = common_download_get_all_parts(model.url);
|
||||
std::vector<common_download_task> tasks;
|
||||
|
||||
// single-part: download straight to model.path if the user gave one (-m), else the cache default
|
||||
if (parts.size() == 1) {
|
||||
common_download_task task;
|
||||
task.url = parts[0];
|
||||
task.local_path = model.path.empty() ? get_default_local_path(parts[0]) : model.path;
|
||||
task.opts = opts;
|
||||
tasks.push_back(std::move(task));
|
||||
return tasks;
|
||||
}
|
||||
|
||||
// multi-part: place each part under the user's -m directory (if given), else the cache default
|
||||
std::string base_dir;
|
||||
if (!model.path.empty()) {
|
||||
auto pos = model.path.rfind('/');
|
||||
base_dir = pos == std::string::npos ? std::string(".") : model.path.substr(0, pos);
|
||||
}
|
||||
|
||||
for (const auto & part : parts) {
|
||||
common_download_task task;
|
||||
task.url = part;
|
||||
task.opts = opts;
|
||||
|
||||
std::string local = get_default_local_path(part);
|
||||
if (!base_dir.empty()) {
|
||||
auto pos = local.rfind('/');
|
||||
std::string name = pos == std::string::npos ? local : local.substr(pos + 1);
|
||||
local = base_dir + "/" + name;
|
||||
}
|
||||
task.local_path = local;
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
return tasks;
|
||||
}
|
||||
|
||||
void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback) {
|
||||
std::vector<common_download_task> tasks;
|
||||
|
||||
auto & plan = handler.plan;
|
||||
auto & plan_spec = handler.plan_spec;
|
||||
auto & plan_voc = handler.plan_voc;
|
||||
|
||||
auto opts = handler.opts; // copy
|
||||
opts.callback = callback;
|
||||
|
||||
// handle plain "url" if needed
|
||||
auto handle_url = [&](common_params_model & model) {
|
||||
if (!model.url.empty()) {
|
||||
if (model.path.empty()) {
|
||||
model.path = get_default_local_path(model.url);
|
||||
}
|
||||
}
|
||||
};
|
||||
handle_url(params.model);
|
||||
handle_url(params.mmproj);
|
||||
handle_url(params.vocoder.model);
|
||||
handle_url(params.speculative.draft.mparams);
|
||||
|
||||
// optionally, if docker repo is set, resolve it
|
||||
if (!params.model.docker_repo.empty()) {
|
||||
params.model.url = common_docker_resolve_model(params.model.docker_repo);
|
||||
params.model.path = get_default_local_path(params.model.url);
|
||||
}
|
||||
|
||||
// handle plain "url" tasks (non-hf)
|
||||
if (!params.model.url.empty()) {
|
||||
auto url_tasks = build_url_tasks(params.model, opts);
|
||||
// the first part is what gets loaded, so point params.model.path at it
|
||||
if (!url_tasks.empty()) {
|
||||
std::string first_path = url_tasks.front().local_path;
|
||||
url_tasks.front().on_done = [&]() { params.model.path = first_path; };
|
||||
}
|
||||
for (auto & task : url_tasks) {
|
||||
tasks.push_back(std::move(task));
|
||||
}
|
||||
}
|
||||
if (!params.mmproj.url.empty()) {
|
||||
common_download_task task;
|
||||
task.url = params.mmproj.url;
|
||||
task.local_path = params.mmproj.path;
|
||||
task.opts = opts;
|
||||
tasks.push_back(task);
|
||||
}
|
||||
if (!params.vocoder.model.url.empty()) {
|
||||
common_download_task task;
|
||||
task.url = params.vocoder.model.url;
|
||||
task.local_path = params.vocoder.model.path;
|
||||
task.opts = opts;
|
||||
tasks.push_back(task);
|
||||
}
|
||||
if (!params.speculative.draft.mparams.url.empty()) {
|
||||
common_download_task task;
|
||||
task.url = params.speculative.draft.mparams.url;
|
||||
task.local_path = params.speculative.draft.mparams.path;
|
||||
task.opts = opts;
|
||||
tasks.push_back(task);
|
||||
}
|
||||
|
||||
// handle hf_plan tasks
|
||||
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files, common_params_model & model) {
|
||||
for (size_t i = 0; i < model_files.size(); ++i) {
|
||||
auto & model_file = model_files[i];
|
||||
bool is_first = (i == 0);
|
||||
tasks.emplace_back(model_file, opts, [&, is_first]() {
|
||||
if (is_first) {
|
||||
// only use first part as model path
|
||||
model.path = hf_cache::finalize_file(model_file);
|
||||
} else {
|
||||
hf_cache::finalize_file(model_file);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
if (!plan.model_files.empty()) {
|
||||
add_tasks(plan.model_files, params.model);
|
||||
}
|
||||
if (!plan.mmproj.local_path.empty()) {
|
||||
tasks.emplace_back(plan.mmproj, opts, [&]() {
|
||||
params.mmproj.path = hf_cache::finalize_file(plan.mmproj);
|
||||
});
|
||||
}
|
||||
if (!plan.mtp.local_path.empty()) {
|
||||
tasks.emplace_back(plan.mtp, opts, [&]() {
|
||||
// only fall back to the discovered MTP head when no draft was explicitly provided
|
||||
if (params.speculative.draft.mparams.empty()) {
|
||||
params.speculative.draft.mparams.path = hf_cache::finalize_file(plan.mtp);
|
||||
} else {
|
||||
hf_cache::finalize_file(plan.mtp);
|
||||
}
|
||||
});
|
||||
}
|
||||
if (!plan.preset.local_path.empty()) {
|
||||
tasks.emplace_back(plan.preset, opts, [&]() {
|
||||
// if HF repo is a preset repo, we simply run server in router mode with the preset.ini file
|
||||
params.models_preset_hf = params.model.hf_repo; // only for showing a warning
|
||||
params.models_preset = res.preset_path;
|
||||
params.models_preset = hf_cache::finalize_file(plan.preset);
|
||||
params.model = common_params_model{}; // make sure to clear model, so server starts in router mode
|
||||
return true;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (params.no_mmproj) {
|
||||
params.mmproj = {};
|
||||
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
|
||||
// optionally, handle mmproj model when -hf is specified
|
||||
params.mmproj = res.mmproj;
|
||||
}
|
||||
// only download mmproj if the current example is using it
|
||||
for (const auto & ex : mmproj_examples) {
|
||||
if (curr_ex == ex) {
|
||||
common_params_handle_model(params.mmproj, sub_opts);
|
||||
break;
|
||||
// handle plan_spec (e.g. --spec-draft-hf)
|
||||
if (!plan_spec.model_files.empty()) {
|
||||
add_tasks(plan_spec.model_files, params.speculative.draft.mparams);
|
||||
}
|
||||
|
||||
// handle vocoder plan (e.g. --hf-repo-v)
|
||||
if (!plan_voc.model_files.empty()) {
|
||||
add_tasks(plan_voc.model_files, params.vocoder.model);
|
||||
}
|
||||
|
||||
// run all tasks in parallel
|
||||
if (!params.offline) {
|
||||
// if duplicated files are found, only download once (but still call on_done for each task)
|
||||
std::unordered_map<std::string, common_download_task *> unique_tasks;
|
||||
for (auto & task : tasks) {
|
||||
auto it = unique_tasks.find(task.local_path);
|
||||
if (it == unique_tasks.end()) {
|
||||
unique_tasks[task.local_path] = &task;
|
||||
}
|
||||
}
|
||||
|
||||
// when --spec-type mtp is set and no draft model was provided explicitly,
|
||||
// fall back to the MTP head discovered alongside the -hf model
|
||||
if (spec_type_draft_mtp && res.found_mtp &&
|
||||
params.speculative.draft.mparams.path.empty() &&
|
||||
params.speculative.draft.mparams.hf_repo.empty() &&
|
||||
params.speculative.draft.mparams.url.empty()) {
|
||||
params.speculative.draft.mparams.path = res.mtp.path;
|
||||
std::vector<common_download_task> unique_tasks_vec;
|
||||
for (auto & pair : unique_tasks) {
|
||||
unique_tasks_vec.push_back(*pair.second);
|
||||
}
|
||||
common_download_run_tasks(unique_tasks_vec);
|
||||
}
|
||||
|
||||
// download successful, update params with the downloaded paths
|
||||
for (const auto & task : tasks) {
|
||||
if (task.on_done) {
|
||||
task.on_done();
|
||||
}
|
||||
common_params_handle_model(params.speculative.draft.mparams, sub_opts);
|
||||
common_params_handle_model(params.vocoder.model, sub_opts);
|
||||
return true;
|
||||
} catch (const common_skip_download_exception &) {
|
||||
return false;
|
||||
} catch (const std::exception &) {
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// CLI argument parsing functions
|
||||
//
|
||||
|
||||
static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) {
|
||||
common_params & params = ctx_arg.params;
|
||||
|
||||
@@ -584,17 +701,22 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
|
||||
}
|
||||
|
||||
// export_graph_ops loads only metadata
|
||||
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
const bool skip_model_download =
|
||||
// server will call common_params_handle_models() later, so we skip it here
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_SERVER ||
|
||||
// download calls common_params_handle_models() itself and prints the paths
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_DOWNLOAD ||
|
||||
// export_graph_ops loads only metadata
|
||||
ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
|
||||
|
||||
if (!skip_model_download) {
|
||||
// handle model and download
|
||||
common_params_handle_models(params, ctx_arg.ex);
|
||||
common_models_handler handler = common_models_handler_init(params, ctx_arg.ex);
|
||||
common_models_handler_apply(handler, params);
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
if (params.model.path.empty()
|
||||
&& ctx_arg.ex != LLAMA_EXAMPLE_SERVER
|
||||
&& !params.usage
|
||||
&& !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
@@ -662,15 +784,19 @@ static void common_params_print_usage(common_params_context & ctx_arg) {
|
||||
common_options.push_back(&opt);
|
||||
}
|
||||
}
|
||||
printf("----- common params -----\n\n");
|
||||
print_options(common_options);
|
||||
printf("\n\n----- sampling params -----\n\n");
|
||||
print_options(sampling_options);
|
||||
printf("\n\n----- speculative params -----\n\n");
|
||||
print_options(spec_options);
|
||||
// TODO: maybe convert enum llama_example to string
|
||||
printf("\n\n----- example-specific params -----\n\n");
|
||||
print_options(specific_options);
|
||||
bool first = true;
|
||||
auto print_section = [&](const char * header, std::vector<common_arg *> & options) {
|
||||
if (options.empty()) {
|
||||
return;
|
||||
}
|
||||
printf("%s----- %s -----\n\n", first ? "" : "\n\n", header);
|
||||
first = false;
|
||||
print_options(options);
|
||||
};
|
||||
print_section("common params", common_options);
|
||||
print_section("sampling params", sampling_options);
|
||||
print_section("speculative params", spec_options);
|
||||
print_section("example-specific params", specific_options);
|
||||
}
|
||||
|
||||
static void common_params_print_completion(common_params_context & ctx_arg) {
|
||||
@@ -1070,7 +1196,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
* - if both {LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_*,} are set, we will prioritize the LLAMA_EXAMPLE_* matching current example
|
||||
*/
|
||||
auto add_opt = [&](common_arg arg) {
|
||||
if ((arg.in_example(ex) || arg.in_example(LLAMA_EXAMPLE_COMMON)) && !arg.is_exclude(ex)) {
|
||||
// download only exposes the handful of args explicitly tagged for it
|
||||
const bool inherit_common = ex != LLAMA_EXAMPLE_DOWNLOAD;
|
||||
if ((arg.in_example(ex) || (inherit_common && arg.in_example(LLAMA_EXAMPLE_COMMON))) && !arg.is_exclude(ex)) {
|
||||
ctx_arg.options.push_back(std::move(arg));
|
||||
}
|
||||
};
|
||||
@@ -1081,7 +1209,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.usage = true;
|
||||
}
|
||||
));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}));
|
||||
add_opt(common_arg(
|
||||
{"--version"},
|
||||
"show version and build info",
|
||||
@@ -2203,7 +2331,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, bool value) {
|
||||
params.no_mmproj = !value;
|
||||
}
|
||||
).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_AUTO"));
|
||||
).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MMPROJ_AUTO"));
|
||||
add_opt(common_arg(
|
||||
{"--mmproj-offload"},
|
||||
{"--no-mmproj-offload"},
|
||||
@@ -2602,14 +2730,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.path = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MODEL"));
|
||||
add_opt(common_arg(
|
||||
{"-mu", "--model-url"}, "MODEL_URL",
|
||||
"model download url (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.url = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_MODEL_URL"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MODEL_URL"));
|
||||
add_opt(common_arg(
|
||||
{ "-dr", "--docker-repo" }, "[<repo>/]<model>[:quant]",
|
||||
"Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.\n"
|
||||
@@ -2618,7 +2746,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.docker_repo = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_DOCKER_REPO"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_DOCKER_REPO"));
|
||||
add_opt(common_arg(
|
||||
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
|
||||
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
|
||||
@@ -2628,14 +2756,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.hf_repo = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_HF_REPO"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_HF_REPO"));
|
||||
add_opt(common_arg(
|
||||
{"-hff", "--hf-file"}, "FILE",
|
||||
"Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.model.hf_file = value;
|
||||
}
|
||||
).set_env("LLAMA_ARG_HF_FILE"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_HF_FILE"));
|
||||
add_opt(common_arg(
|
||||
{"-hfv", "-hfrv", "--hf-repo-v"}, "<user>/<model>[:quant]",
|
||||
"Hugging Face model repository for the vocoder model (default: unused)",
|
||||
@@ -2656,7 +2784,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.hf_token = value;
|
||||
}
|
||||
).set_env("HF_TOKEN"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("HF_TOKEN"));
|
||||
add_opt(common_arg(
|
||||
{"--mtp"},
|
||||
"also download the multi-token prediction (MTP) head, if available (default: unused)",
|
||||
[](common_params & params) {
|
||||
params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_DRAFT_MTP);
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_DOWNLOAD}));
|
||||
add_opt(common_arg(
|
||||
{"--context-file"}, "FNAME",
|
||||
"file to load context from (use comma-separated values to specify multiple files)",
|
||||
@@ -3613,6 +3748,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
"draft model for speculative decoding (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.draft.mparams.path = value;
|
||||
params.speculative.draft.mparams.hf_file = value; // will be used if --spec-draft-hf is set
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
|
||||
add_opt(common_arg(
|
||||
|
||||
+17
-5
@@ -1,12 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "download.h"
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
|
||||
// pseudo-env variable to identify preset-only arguments
|
||||
#define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP"
|
||||
@@ -129,11 +131,21 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<com
|
||||
// see: https://github.com/ggml-org/llama.cpp/issues/18163
|
||||
void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
|
||||
// populate model paths (main model, mmproj, etc) from -hf if necessary
|
||||
// return true if the model is ready to use
|
||||
// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc)
|
||||
// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed)
|
||||
bool common_params_handle_models(common_params & params, llama_example curr_ex);
|
||||
// State produced by common_models_handler_init() and consumed by
// common_models_handler_apply(): the Hugging Face download plans plus the
// download options they were computed with.
// NOTE: the member order is relied upon by brace-initialization of this struct.
struct common_models_handler {
|
||||
// plan for the main model (filled when params.model.hf_repo is set)
common_download_hf_plan plan;
|
||||
// plan for the speculative draft model (filled when its hf_repo is set)
common_download_hf_plan plan_spec;
|
||||
// plan for the vocoder model (filled when its hf_repo is set)
common_download_hf_plan plan_voc;
|
||||
// download options shared by all tasks created from this handler
common_download_opts opts;
|
||||
};
|
||||
|
||||
// initialize downloading opts and hf_plan if needed, but does not download anything yet
|
||||
common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex);
|
||||
|
||||
// check if the model is a preset repo (i.e. has a preset file)
|
||||
bool common_models_handler_is_preset_repo(const common_models_handler & handler);
|
||||
|
||||
// download and update params with the downloaded model path
|
||||
void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback = nullptr);
|
||||
|
||||
// initialize argument parser context - used by test-arg-parser and preset
|
||||
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
|
||||
|
||||
@@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
|
||||
arguments.name_suffix) +
|
||||
arguments.value_prefix +
|
||||
(schema_info.resolves_to_string(param_schema) ?
|
||||
p.tool_arg_string_value(until_suffix) :
|
||||
p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)));
|
||||
p.ac(p.tool_arg_string_value(until_suffix) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) :
|
||||
(p.tool_arg_json_value(p.schema(
|
||||
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
|
||||
p.tool_arg_close(p.literal(arguments.value_suffix)))));
|
||||
|
||||
auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg);
|
||||
if (is_required) {
|
||||
|
||||
+107
-53
@@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
|
||||
return text;
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
|
||||
if (delims.empty() || prompt.empty()) {
|
||||
return {};
|
||||
// Convert a role name ("system", "assistant", "user", "tool") to its enum value.
// Any other string, including the empty string, maps to COMMON_CHAT_ROLE_UNKNOWN.
// The comparison is exact (case-sensitive).
common_chat_role common_chat_role_from_string(const std::string & role) {
    static const struct {
        const char *     name;
        common_chat_role value;
    } known_roles[] = {
        { "system",    COMMON_CHAT_ROLE_SYSTEM    },
        { "assistant", COMMON_CHAT_ROLE_ASSISTANT },
        { "user",      COMMON_CHAT_ROLE_USER      },
        { "tool",      COMMON_CHAT_ROLE_TOOL      },
    };

    for (const auto & entry : known_roles) {
        if (role == entry.name) {
            return entry.value;
        }
    }
    return COMMON_CHAT_ROLE_UNKNOWN;
}
|
||||
|
||||
// Convert a role enum value to its name.
// COMMON_CHAT_ROLE_UNKNOWN (and any out-of-range value) yields an empty string;
// the returned pointer always refers to a string literal and is never null.
const char * common_chat_role_to_string(common_chat_role role) {
    const char * name = "";
    switch (role) {
        case COMMON_CHAT_ROLE_SYSTEM:    name = "system";    break;
        case COMMON_CHAT_ROLE_ASSISTANT: name = "assistant"; break;
        case COMMON_CHAT_ROLE_USER:      name = "user";      break;
        case COMMON_CHAT_ROLE_TOOL:      name = "tool";      break;
        case COMMON_CHAT_ROLE_UNKNOWN:   name = "";          break;
    }
    return name;
}
|
||||
|
||||
// Serialize the delimiters as a JSON array of objects, one per delimiter,
// each with the keys "role" (role name as a string) and "delimiter" (the
// delimiter text). Token lists are not part of the output.
json common_chat_msg_delimiters::to_json() const {
    json out = json::array();
    for (const auto & delim : delimiters) {
        json entry = json::object();
        entry["role"]      = common_chat_role_to_string(delim.role);
        entry["delimiter"] = delim.delimiter;
        out.push_back(entry);
    }
    return out;
}
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) {
|
||||
common_chat_msg_delimiters result;
|
||||
|
||||
if (!delimiters.is_array()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
|
||||
std::vector<std::string> all_delims;
|
||||
std::vector<common_peg_parser> tagged_messages;
|
||||
|
||||
all_delims.reserve(delims.size());
|
||||
tagged_messages.reserve(delims.size());
|
||||
for (const auto & d : delims) {
|
||||
all_delims.push_back(d.delimiter);
|
||||
result.delimiters.reserve(delimiters.size());
|
||||
for (const auto & d : delimiters) {
|
||||
if (!d.is_object()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto any_delim = p.until_one_of(all_delims);
|
||||
for (const auto & d : delims) {
|
||||
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
|
||||
}
|
||||
|
||||
return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
|
||||
});
|
||||
|
||||
common_peg_parse_context ctx(prompt);
|
||||
const auto result = parser.parse(ctx);
|
||||
if (!result.success()) {
|
||||
return {};
|
||||
result.delimiters.push_back({
|
||||
common_chat_role_from_string(d.value("role", std::string())),
|
||||
d.value("delimiter", std::string()),
|
||||
});
|
||||
}
|
||||
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
|
||||
if (!node.tag.empty()) {
|
||||
spans.push_back({ node.tag, node.start, node.end - node.start });
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) {
|
||||
for (auto & d : delimiters) {
|
||||
d.tokens = common_tokenize(vocab, d.delimiter, false, true);
|
||||
}
|
||||
}
|
||||
|
||||
// Split a token sequence into per-role message spans.
//
// The sequence is scanned left to right; at each position the delimiters are
// tried in declaration order and the first one whose tokens match starts a new
// span with that delimiter's role. A span extends up to the next match (or the
// end of the sequence). Tokens located before the first match belong to no span.
//
// `skips` maps a start index to the length of a region that is jumped over
// without any matching (entries are consumed in ascending key order).
//
// Delimiters without tokens (not tokenized yet, or with an empty text) are
// ignored: an empty token list would otherwise compare equal at every position
// and shadow all delimiters declared after it.
common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips) const {
    // (role, start position) of every delimiter occurrence, in order
    std::vector<std::pair<common_chat_role, size_t>> matches;

    auto skip = skips.begin();
    for (size_t i = 0; i < tokens.size();) {
        if (skip != skips.end() && i == skip->first) {
            i += skip->second;
            ++skip;
            continue;
        }

        for (const auto & d : delimiters) {
            if (d.tokens.empty()) {
                continue; // nothing to match against
            }
            if (i + d.tokens.size() > tokens.size()) {
                continue; // delimiter does not fit in the remaining tokens
            }
            if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) {
                matches.emplace_back(d.role, i);
                break;
            }
        }
        i++;
    }

    // sentinel marking the end of the last span
    matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size());

    common_chat_msg_spans spans;
    for (size_t i = 0; i + 1 < matches.size(); i++) {
        const auto & curr = matches[i];
        const auto & next = matches[i + 1];
        spans.add(curr.first, curr.second, next.second - curr.second);
    }

    return spans;
}
|
||||
@@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
|
||||
|
||||
data.prompt = prompt;
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
|
||||
data.message_spans = common_chat_split_by_role(prompt, {
|
||||
{ "assistant", "<|start|>assistant" },
|
||||
{ "user", "<|start|>user" },
|
||||
{ "system", "<|start|>developer" },
|
||||
{ "system", "<|start|>system" },
|
||||
{ "tool", "<|start|>functions" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" },
|
||||
{ COMMON_CHAT_ROLE_USER, "<|start|>user" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" },
|
||||
{ COMMON_CHAT_ROLE_TOOL, "<|start|>functions" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
@@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "user", "<|turn>user\n" },
|
||||
{ "assistant", "<|turn>model\n" },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_USER, "<|turn>user" },
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" },
|
||||
};
|
||||
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
|
||||
data.supports_thinking = true;
|
||||
@@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t
|
||||
RESULT_START, RESULT_END,
|
||||
};
|
||||
|
||||
// Split the rendered prompt into per-role message spans. Tool results are rendered with the
|
||||
// Declare per-role message delimiters. Tool results are rendered with the
|
||||
// system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before
|
||||
// the plain "system" one (it is a strict superset, and the role split tries delimiters in order).
|
||||
data.message_spans = common_chat_split_by_role(data.prompt, {
|
||||
{ "assistant", GEN_PREFIX },
|
||||
{ "user", TURN_START + USER },
|
||||
{ "tool", TURN_START + SYSTEM + RESULT_START },
|
||||
{ "system", TURN_START + SYSTEM },
|
||||
});
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX },
|
||||
{ COMMON_CHAT_ROLE_USER, TURN_START + USER },
|
||||
{ COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM },
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
@@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
|
||||
autoparser.analyze_template(tmpl);
|
||||
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);
|
||||
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
common_chat_msg_delimiters delimiters;
|
||||
if (!autoparser.assistant_start.empty()) {
|
||||
delimiters.push_back({ "assistant", autoparser.assistant_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start);
|
||||
}
|
||||
if (!autoparser.user_start.empty()) {
|
||||
delimiters.push_back({ "user", autoparser.user_start });
|
||||
delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start);
|
||||
}
|
||||
|
||||
if (!delimiters.empty()) {
|
||||
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
|
||||
}
|
||||
auto_params.message_delimiters = std::move(delimiters);
|
||||
|
||||
auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
|
||||
if (auto_params.supports_thinking) {
|
||||
@@ -2708,5 +2758,9 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
|
||||
std::map<std::string, bool> common_chat_templates_get_caps(const common_chat_templates * chat_templates) {
|
||||
GGML_ASSERT(chat_templates != nullptr);
|
||||
GGML_ASSERT(chat_templates->template_default != nullptr);
|
||||
if (chat_templates->template_tool_use != nullptr) {
|
||||
// take the more expressive template when available
|
||||
return chat_templates->template_tool_use->caps.to_map();
|
||||
}
|
||||
return chat_templates->template_default->caps.to_map();
|
||||
}
|
||||
|
||||
+65
-6
@@ -143,15 +143,75 @@ struct common_chat_msg_diff {
|
||||
}
|
||||
};
|
||||
|
||||
// Role of a chat message, used to tag message delimiters and message spans.
enum common_chat_role {
|
||||
// no/unrecognized role; it is the zero value of the enum
COMMON_CHAT_ROLE_UNKNOWN,
|
||||
COMMON_CHAT_ROLE_SYSTEM,
|
||||
COMMON_CHAT_ROLE_ASSISTANT,
|
||||
COMMON_CHAT_ROLE_USER,
|
||||
COMMON_CHAT_ROLE_TOOL
|
||||
};
|
||||
|
||||
common_chat_role common_chat_role_from_string(const std::string & role);
|
||||
const char * common_chat_role_to_string(common_chat_role role);
|
||||
|
||||
// A single message span: a role plus a start position and a length.
// NOTE(review): `role` is declared twice below (std::string and common_chat_role);
// valid() compares against COMMON_CHAT_ROLE_UNKNOWN, which only matches the
// enum-typed declaration - confirm which declaration is current.
struct common_chat_msg_span {
|
||||
std::string role;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
// start offset of the span (defaults to 0)
std::size_t pos = 0;
|
||||
// length of the span (defaults to 0)
std::size_t len = 0;
|
||||
|
||||
// true when the role is anything other than COMMON_CHAT_ROLE_UNKNOWN
bool valid() const {
|
||||
return role != COMMON_CHAT_ROLE_UNKNOWN;
|
||||
}
|
||||
};
|
||||
|
||||
// Ordered list of message spans with helpers for locating user messages.
// NOTE(review): positions are presumably token indices when the spans come from
// common_chat_msg_delimiters::split - confirm against the implementation.
struct common_chat_msg_spans {
|
||||
std::vector<common_chat_msg_span> spans;
|
||||
|
||||
// append a span with the given role, start position and length
void add(common_chat_role role, size_t pos, size_t len) {
|
||||
spans.push_back({ role, pos, len });
|
||||
}
|
||||
|
||||
// returns true if any span with role COMMON_CHAT_ROLE_USER starts exactly at pos
// (the stored size_t position is cast to int32_t for the comparison)
bool is_user_start(int32_t pos) const {
|
||||
for (auto it = spans.begin(); it != spans.end(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// returns the start position of the last USER span in the list (searched from the back),
// cast to int32_t, or -1 when the list contains no USER span
int32_t last_user_message_pos() const {
|
||||
for (auto it = spans.rbegin(); it != spans.rend(); ++it) {
|
||||
if (it->role == COMMON_CHAT_ROLE_USER) {
|
||||
return (int32_t) it->pos;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
// Associates a chat role with the delimiter text that marks the start of a message of that role.
// NOTE(review): `role` and `delimiter` are each declared twice below (a std::string role
// and a common_chat_role role) - confirm which declarations are current.
struct common_chat_msg_delimiter {
|
||||
std::string role;
|
||||
std::string delimiter;
|
||||
common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN;
|
||||
std::string delimiter;
|
||||
// token form of the delimiter; empty by default
// (presumably filled by common_chat_msg_delimiters::tokenize - confirm)
llama_tokens tokens = {};
|
||||
};
|
||||
|
||||
// Collection of per-role message delimiters.
// tokenize(), split() and to_json() are only declared here; their definitions are elsewhere.
struct common_chat_msg_delimiters {
|
||||
std::vector<common_chat_msg_delimiter> delimiters;
|
||||
|
||||
common_chat_msg_delimiters() = default;
|
||||
// construct from a braced list of delimiters (copied into the vector)
common_chat_msg_delimiters(std::initializer_list<common_chat_msg_delimiter> delims) : delimiters(delims) {}
|
||||
|
||||
// append a delimiter for the given role; only role and delimiter text are supplied here
void add(common_chat_role role, const std::string & delimiter) {
|
||||
delimiters.push_back({ role, delimiter });
|
||||
}
|
||||
|
||||
// NOTE(review): presumably fills each delimiter's `tokens` using the given vocab - confirm against the definition
void tokenize(const llama_vocab * vocab);
|
||||
|
||||
// split tokens into message spans. skips maps a start index to a length of a region to jump over without matching
|
||||
common_chat_msg_spans split(const llama_tokens & tokens, const std::map<size_t, size_t> & skips = {}) const;
|
||||
|
||||
// JSON representation of the delimiters (format defined by the implementation)
nlohmann::ordered_json to_json() const;
|
||||
};
|
||||
|
||||
struct common_chat_tool {
|
||||
@@ -219,7 +279,7 @@ struct common_chat_params {
|
||||
std::vector<std::string> preserved_tokens;
|
||||
std::vector<std::string> additional_stops;
|
||||
std::string parser;
|
||||
std::vector<common_chat_msg_span> message_spans;
|
||||
common_chat_msg_delimiters message_delimiters;
|
||||
};
|
||||
|
||||
// per-message parsing syntax
|
||||
@@ -325,5 +385,4 @@ struct common_chat_prompt_preset {
|
||||
|
||||
common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates);
|
||||
|
||||
std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims);
|
||||
|
||||
common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters);
|
||||
|
||||
+47
-47
@@ -225,7 +225,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
}
|
||||
|
||||
if (!SetPriorityClass(GetCurrentProcess(), p)) {
|
||||
LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
|
||||
COM_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -251,7 +251,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
}
|
||||
|
||||
if (setpriority(PRIO_PROCESS, 0, p) != 0) {
|
||||
LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
|
||||
COM_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@@ -284,14 +284,14 @@ void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_para
|
||||
|
||||
if (n_set && n_set < cpuparams.n_threads) {
|
||||
// Not enough set bits, may experience performance issues.
|
||||
LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
|
||||
COM_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
|
||||
}
|
||||
}
|
||||
|
||||
bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) {
|
||||
size_t dash_loc = range.find('-');
|
||||
if (dash_loc == std::string::npos) {
|
||||
LOG_ERR("Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
|
||||
COM_ERR("%s", "Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -303,7 +303,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
|
||||
} else {
|
||||
start_i = std::stoull(range.substr(0, dash_loc));
|
||||
if (start_i >= GGML_MAX_N_THREADS) {
|
||||
LOG_ERR("Start index out of bounds!\n");
|
||||
COM_ERR("%s", "Start index out of bounds!\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -313,7 +313,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
|
||||
} else {
|
||||
end_i = std::stoull(range.substr(dash_loc + 1));
|
||||
if (end_i >= GGML_MAX_N_THREADS) {
|
||||
LOG_ERR("End index out of bounds!\n");
|
||||
COM_ERR("%s", "End index out of bounds!\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -333,7 +333,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
}
|
||||
|
||||
size_t num_digits = mask.length() - start_i;
|
||||
if (num_digits > 128) num_digits = 128;
|
||||
num_digits = std::min<size_t>(num_digits, 128);
|
||||
|
||||
size_t end_i = num_digits + start_i;
|
||||
|
||||
@@ -348,7 +348,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
} else if (c >= 'A' && c <= 'F') {
|
||||
id -= 'A' - 10;
|
||||
} else {
|
||||
LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
|
||||
COM_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -379,21 +379,21 @@ void common_params_print_info(const common_params & params, bool print_devices)
|
||||
#else
|
||||
const char * build_type = " (debug)";
|
||||
#endif
|
||||
LOG_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
|
||||
COM_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
|
||||
|
||||
LOG_INF("log_info: verbosity = %d (adjust with the `-lv N` CLI arg)\n", common_log_get_verbosity_thold());
|
||||
COM_INF("%s: verbosity = %d (adjust with the `-lv N` CLI arg)\n", __func__, common_log_get_verbosity_thold());
|
||||
|
||||
// device enumeration creates a primary context on CUDA backends, skip it when the caller does not own any device
|
||||
if (print_devices) {
|
||||
LOG_INF("device_info:\n");
|
||||
COM_TRC("%s", "device_info:\n");
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
LOG_INF(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
COM_TRC(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
COM_TRC("%s\n", common_params_get_system_info(params).c_str());
|
||||
}
|
||||
|
||||
std::string common_params_get_system_info(const common_params & params) {
|
||||
@@ -660,7 +660,7 @@ void string_process_escapes(std::string & input) {
|
||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
|
||||
const char * sep = strchr(data, '=');
|
||||
if (sep == nullptr || sep - data >= 128) {
|
||||
LOG_ERR("%s: malformed KV override '%s'\n", __func__, data);
|
||||
COM_ERR("%s: malformed KV override '%s'\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
llama_model_kv_override kvo;
|
||||
@@ -683,20 +683,20 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
|
||||
} else if (std::strcmp(sep, "false") == 0) {
|
||||
kvo.val_bool = false;
|
||||
} else {
|
||||
LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
|
||||
COM_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
} else if (strncmp(sep, "str:", 4) == 0) {
|
||||
sep += 4;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
|
||||
if (strlen(sep) > 127) {
|
||||
LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
|
||||
COM_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
strncpy(kvo.val_str, sep, 127);
|
||||
kvo.val_str[127] = '\0';
|
||||
} else {
|
||||
LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
|
||||
COM_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
overrides.emplace_back(std::move(kvo));
|
||||
@@ -1199,8 +1199,8 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory ...\n", __func__);
|
||||
LOG_INF("%s: (for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n", __func__);
|
||||
COM_TRC("%s", "fitting params to device memory ...\n");
|
||||
COM_TRC("%s", "(for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n");
|
||||
common_fit_params(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split,
|
||||
params.tensor_buft_overrides.data(),
|
||||
@@ -1227,7 +1227,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
llama_adapter_lora_ptr lora;
|
||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
if (lora == nullptr) {
|
||||
LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str());
|
||||
COM_ERR("failed to load lora adapter '%s'\n", la.path.c_str());
|
||||
pimpl->model.reset(model);
|
||||
return;
|
||||
}
|
||||
@@ -1246,14 +1246,14 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
common_init_sampler_from_model(model, params.sampling);
|
||||
|
||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have an EOS token, ignoring --ignore-eos\n");
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_TRC("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
|
||||
COM_TRC("added %s logit bias = %f\n", common_token_to_piece(vocab, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
@@ -1291,7 +1291,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1328,7 +1328,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
|
||||
llama_model * model = res->model();
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
|
||||
COM_ERR("failed to load model '%s'\n", params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -1338,14 +1338,14 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
|
||||
llama_context * lctx = res->context();
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
|
||||
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
|
||||
COM_WRN("%s", "KV cache shifting is not supported for this context, disabling KV cache shifting\n");
|
||||
params.ctx_shift = false;
|
||||
}
|
||||
|
||||
@@ -1374,7 +1374,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
bool ok = true;
|
||||
|
||||
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have a BOS token, reranking will not work\n");
|
||||
ok = false;
|
||||
}
|
||||
|
||||
@@ -1383,10 +1383,10 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;
|
||||
|
||||
if (!has_eos && !has_sep && !has_rerank_prompt) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n");
|
||||
ok = false;
|
||||
} else if (!has_eos) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have an EOS token, using SEP token as fallback\n");
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
@@ -1399,7 +1399,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
}
|
||||
|
||||
if (params.warmup) {
|
||||
LOG_INF("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||
COM_TRC("%s", "warming up the model with an empty run - please wait ... (--no-warmup to disable)\n");
|
||||
|
||||
std::vector<llama_token> tmp;
|
||||
llama_token bos = llama_vocab_bos(vocab);
|
||||
@@ -1473,20 +1473,20 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
|
||||
|
||||
int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
|
||||
COM_ERR("llama_decode() failed: %d\n", ret);
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
goto done;
|
||||
}
|
||||
|
||||
if (llama_n_rs_seq(ctx) > 0) {
|
||||
LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__);
|
||||
COM_TRC("%s", "the context supports bounded partial sequence removal\n");
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_RS;
|
||||
goto done;
|
||||
}
|
||||
|
||||
// try to remove the last tokens
|
||||
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
|
||||
LOG_TRC("%s: the context does not support partial sequence removal\n", __func__);
|
||||
COM_TRC("%s", "the context does not support partial sequence removal\n");
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
goto done;
|
||||
}
|
||||
@@ -1803,13 +1803,13 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
};
|
||||
struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
|
||||
if (!ctx_gguf) {
|
||||
LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("failed to load control vector file from %s\n", load_info.fname.c_str());
|
||||
return result;
|
||||
}
|
||||
|
||||
int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
|
||||
if (n_tensors == 0) {
|
||||
LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_WRN("no direction tensors found in %s\n", load_info.fname.c_str());
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_tensors; i++) {
|
||||
@@ -1827,23 +1827,23 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
}
|
||||
}
|
||||
if (layer_idx < 0) {
|
||||
LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid/unparsable direction tensor layer index in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
} else if (layer_idx == 0) {
|
||||
LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid (zero) direction tensor layer index in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
|
||||
struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
|
||||
if (tensor->type != GGML_TYPE_F32) {
|
||||
LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid (non-F32) direction tensor type in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
if (ggml_n_dims(tensor) != 1) {
|
||||
LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid (non-1D) direction tensor shape in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
@@ -1851,7 +1851,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
if (result.n_embd == -1) {
|
||||
result.n_embd = ggml_nelements(tensor);
|
||||
} else if (ggml_nelements(tensor) != result.n_embd) {
|
||||
LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("direction tensor in %s does not match previous dimensions\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
@@ -1868,7 +1868,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
}
|
||||
|
||||
if (result.n_embd == -1) {
|
||||
LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
|
||||
COM_WRN("skipping %s due to invalid direction tensors\n", load_info.fname.c_str());
|
||||
result.data.clear();
|
||||
}
|
||||
|
||||
@@ -1889,7 +1889,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
|
||||
break;
|
||||
}
|
||||
if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
|
||||
LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
|
||||
COM_ERR("control vectors in %s does not match previous dimensions\n", info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
@@ -1905,7 +1905,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
|
||||
}
|
||||
|
||||
if (result.n_embd == -1) {
|
||||
LOG_ERR("%s: no valid control vector files passed\n", __func__);
|
||||
COM_ERR("%s", "no valid control vector files passed\n");
|
||||
result.data.clear();
|
||||
}
|
||||
|
||||
@@ -2016,13 +2016,13 @@ bool common_prompt_batch_decode(
|
||||
// memory, so we can't just remove the last token from the memory and replay the last token which
|
||||
// is the reason for this logic.
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_tokens_before_last))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
COM_ERR("%s", "failed to eval\n");
|
||||
return false;
|
||||
}
|
||||
n_past += n_tokens_before_last;
|
||||
|
||||
llama_state_save_file(ctx, state_path.data(), all_tokens.data(), all_tokens.size());
|
||||
LOG_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
|
||||
COM_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
|
||||
|
||||
llama_token last_token = all_tokens.back();
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
@@ -2030,13 +2030,13 @@ bool common_prompt_batch_decode(
|
||||
batch.pos = &pos;
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval last token\n", __func__);
|
||||
COM_ERR("%s", "failed to eval last token\n");
|
||||
return false;
|
||||
}
|
||||
n_past++;
|
||||
} else {
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_new))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
COM_ERR("%s", "failed to eval\n");
|
||||
return false;
|
||||
}
|
||||
n_past += n_new;
|
||||
|
||||
+20
-9
@@ -25,6 +25,13 @@
|
||||
#define DIRECTORY_SEPARATOR '/'
|
||||
#endif // _WIN32
|
||||
|
||||
// Logging wrappers: each COM_XXX forwards to the matching LOG_XXX macro and
// prepends "cmn <function>: " to the message.
//  - "%12.*s" is given the precision 12 and __func__, so the function name is padded
//    to a width of 12 and truncated to at most 12 characters.
//  - fmt is pasted next to a string literal, so it must itself be a string literal.
//  - __VA_ARGS__ follows an unconditional comma, so at least one variadic argument is
//    required; messages without arguments are written as COM_XXX("%s", "text\n").
// COM_CNT forwards to LOG_CNT without adding the prefix.
#define COM_DBG(fmt, ...) LOG_DBG("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_TRC(fmt, ...) LOG_TRC("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_INF(fmt, ...) LOG_INF("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_WRN(fmt, ...) LOG_WRN("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_ERR(fmt, ...) LOG_ERR("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
|
||||
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
|
||||
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
|
||||
|
||||
@@ -96,6 +103,7 @@ enum llama_example {
|
||||
LLAMA_EXAMPLE_FIT_PARAMS,
|
||||
LLAMA_EXAMPLE_RESULTS,
|
||||
LLAMA_EXAMPLE_EXPORT_GRAPH_OPS,
|
||||
LLAMA_EXAMPLE_DOWNLOAD,
|
||||
|
||||
LLAMA_EXAMPLE_COUNT,
|
||||
};
|
||||
@@ -290,13 +298,13 @@ struct common_params_sampling {
|
||||
};
|
||||
|
||||
struct common_params_model {
|
||||
std::string path = ""; // model local path // NOLINT
|
||||
std::string url = ""; // model url to download // NOLINT
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
std::string hf_file = ""; // HF file // NOLINT
|
||||
std::string docker_repo = ""; // Docker repo // NOLINT
|
||||
std::string path = ""; // model local path
|
||||
std::string url = ""; // model url to download
|
||||
std::string hf_repo = ""; // HF repo
|
||||
std::string hf_file = ""; // HF file
|
||||
std::string docker_repo = ""; // Docker repo
|
||||
|
||||
std::string get_name() {
|
||||
std::string get_name() const {
|
||||
if (!hf_repo.empty()) {
|
||||
return hf_repo;
|
||||
}
|
||||
@@ -305,6 +313,10 @@ struct common_params_model {
|
||||
}
|
||||
return path;
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
return get_name().empty();
|
||||
}
|
||||
};
|
||||
|
||||
// draft-model-based speculative decoding parameters
|
||||
@@ -367,7 +379,7 @@ struct common_params_speculative {
|
||||
common_params_speculative_ngram_cache ngram_cache;
|
||||
|
||||
bool has_dft() const {
|
||||
return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
|
||||
return !draft.mparams.empty();
|
||||
}
|
||||
|
||||
uint32_t need_n_rs_seq() const {
|
||||
@@ -519,7 +531,6 @@ struct common_params {
|
||||
int32_t control_vector_layer_start = -1; // layer range for control vector
|
||||
int32_t control_vector_layer_end = -1; // layer range for control vector
|
||||
bool offline = false;
|
||||
bool skip_download = false; // skip model file downloading
|
||||
|
||||
int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
|
||||
int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
|
||||
@@ -609,7 +620,7 @@ struct common_params {
|
||||
bool cache_prompt = true; // whether to enable prompt caching
|
||||
bool cache_idle_slots = true; // save and clear idle slots upon starting a new task
|
||||
int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot
|
||||
int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints
|
||||
int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints
|
||||
int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc.
|
||||
|
||||
std::string hostname = "127.0.0.1";
|
||||
|
||||
+22
-116
@@ -292,10 +292,6 @@ static int common_download_file_single_online(const std::string & url,
|
||||
|
||||
const bool file_exists = std::filesystem::exists(path);
|
||||
|
||||
if (!file_exists && opts.skip_download) {
|
||||
return -2; // file is missing and download is disabled
|
||||
}
|
||||
|
||||
if (file_exists && skip_etag) {
|
||||
LOG_DBG("%s: using cached file: %s\n", __func__, path.c_str());
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
@@ -362,9 +358,6 @@ static int common_download_file_single_online(const std::string & url,
|
||||
return 304; // 304 Not Modified - fake cached response
|
||||
}
|
||||
// past this point, the file exists but is different from the server version, so we need to redownload it
|
||||
if (opts.skip_download) {
|
||||
return -2; // special code to indicate that the download was skipped due to etag mismatch
|
||||
}
|
||||
if (remove(path.c_str()) != 0) {
|
||||
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
||||
return -1;
|
||||
@@ -691,19 +684,8 @@ static void list_available_gguf_files(const hf_cache::hf_files & files) {
|
||||
}
|
||||
}
|
||||
|
||||
struct hf_plan {
|
||||
hf_cache::hf_file primary;
|
||||
hf_cache::hf_files model_files;
|
||||
hf_cache::hf_file mmproj;
|
||||
hf_cache::hf_file mtp;
|
||||
hf_cache::hf_file preset; // if set, only this file is downloaded
|
||||
};
|
||||
|
||||
static hf_plan get_hf_plan(const common_params_model & model,
|
||||
const common_download_opts & opts,
|
||||
bool download_mmproj,
|
||||
bool download_mtp) {
|
||||
hf_plan plan;
|
||||
common_download_hf_plan common_download_get_hf_plan(const common_params_model & model, const common_download_opts & opts) {
|
||||
common_download_hf_plan plan;
|
||||
hf_cache::hf_files all;
|
||||
|
||||
auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);
|
||||
@@ -752,125 +734,49 @@ static hf_plan get_hf_plan(const common_params_model & model,
|
||||
plan.primary = primary;
|
||||
plan.model_files = get_split_files(all, primary);
|
||||
|
||||
if (download_mmproj) {
|
||||
if (opts.download_mmproj) {
|
||||
plan.mmproj = find_best_mmproj(all, primary.path);
|
||||
}
|
||||
|
||||
if (download_mtp) {
|
||||
if (opts.download_mtp) {
|
||||
plan.mtp = find_best_mtp(all, primary.path);
|
||||
}
|
||||
|
||||
return plan;
|
||||
}
|
||||
|
||||
struct download_task {
|
||||
std::string url;
|
||||
std::string path;
|
||||
};
|
||||
|
||||
static std::vector<download_task> get_url_tasks(const common_params_model & model) {
|
||||
auto split = get_gguf_split_info(model.url);
|
||||
|
||||
if (split.count <= 1) {
|
||||
return {{model.url, model.path}};
|
||||
}
|
||||
|
||||
auto filename = split.prefix;
|
||||
if (auto pos = split.prefix.rfind('/'); pos != std::string::npos) {
|
||||
filename = split.prefix.substr(pos + 1);
|
||||
}
|
||||
|
||||
auto parent_path = std::filesystem::path(model.path).parent_path();
|
||||
auto prefix_path = (parent_path / filename).string();
|
||||
|
||||
std::vector<download_task> tasks;
|
||||
for (int i = 1; i <= split.count; i++) {
|
||||
auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count);
|
||||
tasks.push_back({split.prefix + suffix, prefix_path + suffix});
|
||||
}
|
||||
return tasks;
|
||||
}
|
||||
|
||||
common_download_model_result common_download_model(const common_params_model & model,
|
||||
const common_download_opts & opts) {
|
||||
common_download_model_result result;
|
||||
std::vector<download_task> tasks;
|
||||
hf_plan hf;
|
||||
|
||||
bool download_mmproj = opts.download_mmproj;
|
||||
bool download_mtp = opts.download_mtp;
|
||||
bool is_hf = !model.hf_repo.empty();
|
||||
|
||||
if (is_hf) {
|
||||
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
|
||||
if (!hf.preset.path.empty()) {
|
||||
// if preset.ini exists, only download that file alone
|
||||
tasks.push_back({hf.preset.url, hf.preset.local_path});
|
||||
} else {
|
||||
for (const auto & f : hf.model_files) {
|
||||
tasks.push_back({f.url, f.local_path});
|
||||
}
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
|
||||
}
|
||||
if (!hf.mtp.path.empty()) {
|
||||
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
|
||||
}
|
||||
}
|
||||
} else if (!model.url.empty()) {
|
||||
tasks = get_url_tasks(model);
|
||||
} else {
|
||||
result.model_path = model.path;
|
||||
return result;
|
||||
}
|
||||
|
||||
if (tasks.empty()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
void common_download_run_tasks(const std::vector<common_download_task> & tasks) {
|
||||
std::vector<std::future<int>> futures;
|
||||
for (const auto & task : tasks) {
|
||||
futures.push_back(std::async(std::launch::async,
|
||||
[&task, &opts, is_hf]() {
|
||||
return common_download_file_single(task.url, task.path, opts, is_hf);
|
||||
[&task]() {
|
||||
return common_download_file_single(task.url, task.local_path, task.opts, task.is_hf);
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
for (auto & f : futures) {
|
||||
int status = f.get();
|
||||
if (status == -2 && opts.skip_download) {
|
||||
throw common_skip_download_exception();
|
||||
}
|
||||
for (size_t i = 0; i < futures.size(); ++i) {
|
||||
std::string url = tasks[i].url;
|
||||
int status = futures[i].get();
|
||||
bool is_ok = is_http_status_ok(status);
|
||||
if (!is_ok) {
|
||||
return {};
|
||||
throw std::runtime_error(string_format("Download '%s' failed with status code: %d", url.c_str(), status));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (is_hf) {
|
||||
if (!hf.preset.path.empty()) {
|
||||
// if preset.ini is used, do not set other paths
|
||||
result.preset_path = hf_cache::finalize_file(hf.preset);
|
||||
} else {
|
||||
for (const auto & f : hf.model_files) {
|
||||
hf_cache::finalize_file(f);
|
||||
}
|
||||
result.model_path = hf.primary.final_path;
|
||||
std::vector<std::string> common_download_get_all_parts(const std::string & url) {
|
||||
auto split = get_gguf_split_info(url);
|
||||
|
||||
if (!hf.mmproj.path.empty()) {
|
||||
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
|
||||
}
|
||||
|
||||
if (!hf.mtp.path.empty()) {
|
||||
result.mtp_path = hf_cache::finalize_file(hf.mtp);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.model_path = model.path;
|
||||
if (split.count <= 1) {
|
||||
return {url};
|
||||
}
|
||||
|
||||
return result;
|
||||
std::vector<std::string> parts;
|
||||
for (int i = 1; i <= split.count; i++) {
|
||||
auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count);
|
||||
parts.push_back(split.prefix + suffix);
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
+28
-42
@@ -1,7 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "hf-cache.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
struct common_params_model;
|
||||
|
||||
@@ -47,66 +50,40 @@ struct common_cached_model_info {
|
||||
}
|
||||
};
|
||||
|
||||
// Options for common_download_model and common_download_file_single
|
||||
// Options for common_download_file_single
|
||||
struct common_download_opts {
|
||||
std::string bearer_token;
|
||||
common_header_list headers;
|
||||
bool offline = false;
|
||||
bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid
|
||||
bool download_mmproj = false;
|
||||
bool download_mtp = false;
|
||||
common_download_callback * callback = nullptr;
|
||||
};
|
||||
|
||||
// Result of common_download_model
|
||||
struct common_download_model_result {
|
||||
std::string model_path;
|
||||
std::string mmproj_path;
|
||||
std::string mtp_path;
|
||||
std::string preset_path;
|
||||
// Describes one file download: where to fetch it from, where to store it, and
// the options to use while doing so.
struct common_download_task {
|
||||
// download options used for this task
common_download_opts opts;
|
||||
// source URL of the file
std::string url;
|
||||
// destination path on the local filesystem
std::string local_path;
|
||||
// optional callback; presumably invoked once the download has finished - TODO confirm against common_download_run_tasks
std::function<void()> on_done;
|
||||
// true when the task was built from an hf_cache::hf_file (see the constructor below)
bool is_hf = false;
|
||||
|
||||
common_download_task() = default;
|
||||
// Builds a task from a HuggingFace cache file entry: copies its url and
// local_path, stores the given options/callback and marks the task as is_hf.
common_download_task(hf_cache::hf_file f,
|
||||
const common_download_opts & opts,
|
||||
std::function<void()> on_done = nullptr)
|
||||
: opts(opts), url(f.url), local_path(f.local_path), on_done(on_done), is_hf(true) {}
|
||||
};
|
||||
|
||||
// thrown if the file is missing or invalid (e.g. ETag check failed)
|
||||
struct common_skip_download_exception : public std::runtime_error {
|
||||
common_skip_download_exception() : std::runtime_error("skip download") {}
|
||||
};
|
||||
void common_download_run_tasks(const std::vector<common_download_task> & tasks);
|
||||
|
||||
// Download model from HuggingFace repo or URL
|
||||
//
|
||||
// input (via model struct):
|
||||
// - model.hf_repo: HF repo with optional tag, see common_download_split_repo_tag
|
||||
// - model.hf_file: specific file in the repo (requires hf_repo)
|
||||
// - model.url: simple download (used if hf_repo is empty)
|
||||
// - model.path: local file path
|
||||
//
|
||||
// tag matching (for HF repos without model.hf_file):
|
||||
// - if tag is specified, searches for GGUF matching that quantization
|
||||
// - if no tag, searches for Q4_K_M, then Q4_0, then first available GGUF
|
||||
//
|
||||
// split GGUF: multi-part files like "model-00001-of-00003.gguf" are automatically
|
||||
// detected and all parts are downloaded
|
||||
//
|
||||
// caching:
|
||||
// - HF repos: uses HuggingFace cache
|
||||
// - URLs: uses ETag-based caching
|
||||
//
|
||||
// when opts.offline=true, no network requests are made
|
||||
// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
|
||||
// then with the closest quantization bits
|
||||
// when download_mtp=true, applies the same sibling search for an MTP-head GGUF
|
||||
//
|
||||
// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure)
|
||||
common_download_model_result common_download_model(
|
||||
const common_params_model & model,
|
||||
const common_download_opts & opts = {}
|
||||
);
|
||||
// if url is a multi-part GGUF file, returns all parts, otherwise returns the single file
|
||||
std::vector<std::string> common_download_get_all_parts(const std::string & url);
|
||||
|
||||
// returns list of cached models
|
||||
std::vector<common_cached_model_info> common_list_cached_models();
|
||||
|
||||
// download single file from url to local path
|
||||
// returns status code or -1 on error
|
||||
// returns -2 if the download was skipped due to ETag mismatch (file outdated, skip_download=true)
|
||||
// skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash)
|
||||
int common_download_file_single(const std::string & url,
|
||||
const std::string & path,
|
||||
@@ -123,3 +100,12 @@ std::string common_docker_resolve_model(const std::string & docker);
|
||||
// - if tag is present, removes only files matching that tag (and orphaned blobs)
|
||||
// returns true if anything was removed
|
||||
bool common_download_remove(const std::string & hf_repo_with_tag);
|
||||
|
||||
struct common_download_hf_plan {
|
||||
hf_cache::hf_file primary;
|
||||
hf_cache::hf_files model_files;
|
||||
hf_cache::hf_file mmproj;
|
||||
hf_cache::hf_file mtp;
|
||||
hf_cache::hf_file preset; // if set, only this file is downloaded
|
||||
};
|
||||
common_download_hf_plan common_download_get_hf_plan(const common_params_model & model, const common_download_opts & opts);
|
||||
|
||||
+1
-1
@@ -233,7 +233,7 @@ static void common_params_fit_impl(
|
||||
sum_projected_used = dmds_full.back().mb.total();
|
||||
sum_free = dmds_full.back().total;
|
||||
sum_projected_free = sum_free - sum_projected_used;
|
||||
LOG_INF("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
|
||||
LOG_TRC("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
|
||||
__func__, sum_projected_used/MiB, sum_free/MiB);
|
||||
if (sum_projected_free >= margins[0]) {
|
||||
LOG_TRC("%s: will leave %" PRId64 " >= %" PRId64 " MiB of system memory, no changes needed\n",
|
||||
|
||||
@@ -1,324 +0,0 @@
|
||||
#include "json-partial.h"
|
||||
|
||||
#include "log.h"
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <regex>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// Kind of entry kept on the SAX error locator's stack while walking the JSON input.
enum common_json_stack_element_type {
|
||||
// pushed by start_object(), popped by end_object()
COMMON_JSON_STACK_ELEMENT_OBJECT,
|
||||
// pushed by key(), popped by close_value() once the value for that key is complete
COMMON_JSON_STACK_ELEMENT_KEY,
|
||||
// pushed by start_array(), popped by end_array()
COMMON_JSON_STACK_ELEMENT_ARRAY,
|
||||
};
|
||||
|
||||
// One entry of the parser stack used to know which containers are still open
// when the input stops early.
struct common_json_stack_element {
|
||||
// what this entry represents (object, key or array)
common_json_stack_element_type type;
|
||||
// the key text for KEY entries; object and array entries are pushed with an empty string
std::string key;
|
||||
};
|
||||
|
||||
// Convenience overload: parses the whole `input` string by delegating to the
// iterator-based overload. The iterator position reached by the parser is
// local to this function and is discarded; only the boolean result and `out`
// are reported to the caller.
bool common_json_parse(
|
||||
const std::string & input,
|
||||
const std::string & healing_marker,
|
||||
common_json & out)
|
||||
{
|
||||
// the iterator overload takes `it` by non-const reference, so a named local is required
std::string::const_iterator it = input.begin();
|
||||
const auto end = input.end();
|
||||
return common_json_parse(it, end, healing_marker, out);
|
||||
}
|
||||
|
||||
bool common_json_parse(
|
||||
std::string::const_iterator & it,
|
||||
const std::string::const_iterator & end,
|
||||
const std::string & healing_marker,
|
||||
common_json & out)
|
||||
{
|
||||
// https://json.nlohmann.me/features/parsing/sax_interface/
|
||||
struct json_error_locator : public nlohmann::json_sax<json> {
|
||||
std::size_t position;
|
||||
bool found_error;
|
||||
std::string last_token;
|
||||
std::string exception_message;
|
||||
std::vector<common_json_stack_element> stack;
|
||||
|
||||
json_error_locator() : position(0), found_error(false) {}
|
||||
|
||||
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
|
||||
this->position = position - 1;
|
||||
this->found_error = true;
|
||||
this->last_token = last_token;
|
||||
this->exception_message = ex.what();
|
||||
return false;
|
||||
}
|
||||
void close_value() {
|
||||
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
|
||||
stack.pop_back();
|
||||
}
|
||||
}
|
||||
bool null() override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool boolean(bool) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_integer(number_integer_t) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_unsigned(number_unsigned_t) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool number_float(number_float_t, const string_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool string(string_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool binary(binary_t &) override { // NOLINT
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool start_object(std::size_t) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
|
||||
return true;
|
||||
}
|
||||
bool end_object() override {
|
||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
|
||||
stack.pop_back();
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
bool key(string_t & key) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
|
||||
return true;
|
||||
}
|
||||
bool start_array(std::size_t) override { // NOLINT
|
||||
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
|
||||
return true;
|
||||
}
|
||||
bool end_array() override {
|
||||
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
|
||||
stack.pop_back();
|
||||
close_value();
|
||||
return true;
|
||||
}
|
||||
};
|
||||
json_error_locator err_loc;
|
||||
auto start = it;
|
||||
json::sax_parse(it, end, &err_loc);
|
||||
|
||||
if (err_loc.found_error) {
|
||||
it = start;
|
||||
auto temptative_end = it + err_loc.position;
|
||||
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
|
||||
|
||||
auto input = std::string(it, temptative_end);
|
||||
try {
|
||||
out.json = json::parse(input);
|
||||
// out.json = json::parse(it, temptative_end);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
} catch (const std::exception & ex) {
|
||||
// No, needs healing.
|
||||
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
|
||||
}
|
||||
auto can_parse = [](const std::string & str) {
|
||||
try {
|
||||
auto _ = json::parse(str); // NOLINT
|
||||
return true;
|
||||
} catch (const std::exception &) {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
if (!healing_marker.empty() && !err_loc.stack.empty()) {
|
||||
std::string str(it, temptative_end);
|
||||
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
|
||||
if (last_non_sp_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
||||
}
|
||||
auto last_non_sp_char = str[last_non_sp_pos];
|
||||
// Used to detect stops on a number, which may not be complete.
|
||||
auto was_maybe_number = [&]() {
|
||||
if (!str.empty() && std::isspace(str.back())) {
|
||||
return false;
|
||||
}
|
||||
return std::isdigit(last_non_sp_char) ||
|
||||
last_non_sp_char == '.' ||
|
||||
last_non_sp_char == 'e' ||
|
||||
last_non_sp_char == 'E' ||
|
||||
last_non_sp_char == '-';
|
||||
};
|
||||
|
||||
std::string closing;
|
||||
for (size_t i = err_loc.stack.size(); i > 0; i--) {
|
||||
auto & el = err_loc.stack[i - 1];
|
||||
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
||||
closing += "}";
|
||||
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
||||
closing += "]";
|
||||
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
throw std::runtime_error("Unexpected stack element type");
|
||||
}
|
||||
}
|
||||
|
||||
// Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
|
||||
static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
|
||||
|
||||
auto is_high_surrogate = [&](const std::string & s) {
|
||||
// Check whether s starts with a (possibly partial) high surrogate escape (U+D800-U+DBFF)
|
||||
return s.length() >= 4 &&
|
||||
s[0] == '\\' && s[1] == 'u' &&
|
||||
std::tolower(s[2]) == 'd' &&
|
||||
(s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
|
||||
};
|
||||
|
||||
// Initialize the unicode marker to a low surrogate to handle the edge case
|
||||
// where a high surrogate (U+D800-U+DBFF) is immediately followed by a
|
||||
// backslash (\)
|
||||
std::string unicode_marker_padding = "udc00";
|
||||
std::smatch last_unicode_seq;
|
||||
|
||||
if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
|
||||
std::smatch second_last_seq;
|
||||
std::string prelude = str.substr(0, last_unicode_seq.position());
|
||||
|
||||
// Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
|
||||
unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
|
||||
|
||||
if (is_high_surrogate(last_unicode_seq.str())) {
|
||||
// If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+DFFF)
|
||||
unicode_marker_padding += "\\udc00";
|
||||
} else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
|
||||
if (is_high_surrogate(second_last_seq.str())) {
|
||||
// If this follows a high surrogate, pad it to be a low surrogate
|
||||
if (last_unicode_seq.length() == 2) {
|
||||
unicode_marker_padding = "dc00";
|
||||
} else if (last_unicode_seq.length() == 3) {
|
||||
unicode_marker_padding = "c00";
|
||||
} else {
|
||||
// The original unicode_marker_padding is already padded with 0s
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
|
||||
|
||||
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
|
||||
// We're inside an object value
|
||||
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
|
||||
// Was about to create an object value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + ": 1" + closing)) {
|
||||
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
|
||||
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
|
||||
// Was about to create an object
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + "\"" + closing)) {
|
||||
// Was inside an object value string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an object value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an object value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
// find last :
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
|
||||
}
|
||||
// Cutting back to opening : for object value
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
|
||||
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
|
||||
// Was about to create an array value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + "\"" + closing)) {
|
||||
// Was inside an array value string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
|
||||
// Was inside an array value string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
|
||||
// Was inside an array value string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
|
||||
// Had just finished a value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of("[,");
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
|
||||
}
|
||||
// Cutting back to last [ or , for array value
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
|
||||
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
|
||||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
|
||||
// Was about to create an object key+value
|
||||
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
|
||||
// Was about to create an object key+value
|
||||
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + "\": 1" + closing)) {
|
||||
// Was inside an object key string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
|
||||
// Was inside an object key string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
|
||||
} else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
|
||||
// Was inside an object key string after a partial unicode escape
|
||||
str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
|
||||
} else {
|
||||
auto last_pos = str.find_last_of(':');
|
||||
if (last_pos == std::string::npos) {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
||||
}
|
||||
// fprintf(stderr, "Cutting back to last : for object key+value\n");
|
||||
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
|
||||
}
|
||||
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
|
||||
out.json = json::parse(str);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
// handle unclosed top-level primitive
|
||||
if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
|
||||
std::string str(it, temptative_end);
|
||||
const auto & magic_seed = out.healing_marker.marker = healing_marker;
|
||||
if (can_parse(str + "\"")) {
|
||||
// Was inside a string
|
||||
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
|
||||
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
|
||||
// Was inside a string after an escape
|
||||
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
|
||||
} else {
|
||||
// TODO: handle more unclosed top-level primitives if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
|
||||
// fprintf(stderr, "Closing: TODO\n");
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(str);
|
||||
it = temptative_end;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
out.json = json::parse(it, end);
|
||||
it = end;
|
||||
return true;
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: use json_fwd.hpp when possible
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
|
||||
struct common_healing_marker {
|
||||
// Raw marker: a copy of the `healing_marker` argument given to common_json_parse,
// stored when healing is attempted.
std::string marker;
|
||||
|
||||
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
// Built from the raw marker, possibly prefixed with the characters needed to
// complete the truncated token (a quote, colon, comma, backslash or unicode
// escape padding), depending on where the input stopped.
std::string json_dump_marker;
|
||||
};
|
||||
|
||||
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
|
||||
struct common_json {
|
||||
nlohmann::ordered_json json;
|
||||
|
||||
common_healing_marker healing_marker;
|
||||
};
|
||||
|
||||
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
|
||||
//
|
||||
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
|
||||
// This allows parsing the resulting healed JSON string while still being able to cut it again at the healing marker if needed.
|
||||
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
|
||||
//
|
||||
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
|
||||
bool common_json_parse(
|
||||
const std::string & input,
|
||||
const std::string & healing_marker,
|
||||
common_json & out);
|
||||
|
||||
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
|
||||
bool common_json_parse(
|
||||
std::string::const_iterator & it,
|
||||
const std::string::const_iterator & end,
|
||||
const std::string & healing_marker,
|
||||
common_json & out);
|
||||
+102
-29
@@ -921,6 +921,10 @@ struct parser_executor {
|
||||
common_peg_parse_result operator()(const common_peg_gbnf_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
|
||||
// ac parser: parsing is delegated entirely to the wrapped child; the
// delimiters are not consulted here.
common_peg_parse_result operator()(const common_peg_ac_parser & p) {
|
||||
return arena.parse(p.child, ctx, start_pos);
|
||||
}
|
||||
};
|
||||
|
||||
common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const {
|
||||
@@ -989,7 +993,8 @@ void common_peg_arena::resolve_refs() {
|
||||
std::is_same_v<T, common_peg_not_parser> ||
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_ac_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
p.child = resolve_ref(p.child);
|
||||
@@ -1070,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id
|
||||
return "Atomic(" + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_any_parser>) {
|
||||
return "Any";
|
||||
} else if constexpr (std::is_same_v<T, common_peg_space_parser>) {
|
||||
@@ -1479,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key
|
||||
});
|
||||
}
|
||||
|
||||
// Wraps parser `p` in a common_peg_ac_parser carrying `delimiters`.
// Throws std::runtime_error when `delimiters` is empty.
common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector<std::string> & delimiters) {
|
||||
// reject an empty delimiter list up front
if (delimiters.empty()) {
|
||||
throw std::runtime_error("ac parser requires at least one delimiter");
|
||||
}
|
||||
return add(common_peg_ac_parser{p, delimiters});
|
||||
}
|
||||
|
||||
static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
if (c == '-' || c == ']' || c == '[' || c == '\\') {
|
||||
return "\\" + std::string(1, (char) c);
|
||||
@@ -1529,14 +1543,22 @@ static std::string gbnf_escape_char_class(uint32_t c) {
|
||||
return std::string(buf);
|
||||
}
|
||||
|
||||
// GBNF grammar matching strings that contain no string in `strings` as a
|
||||
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
//
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
|
||||
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
// Builds a bracketed character class from `chars`: "[...]" when `negate` is
// false, "[^...]" when it is true. Each code point is escaped with
// gbnf_escape_char_class before being appended.
static std::string gbnf_char_class(const std::vector<uint32_t> & chars, bool negate) {
|
||||
std::string s = negate ? "[^" : "[";
|
||||
for (uint32_t ch : chars) {
|
||||
s += gbnf_escape_char_class(ch);
|
||||
}
|
||||
return s + "]";
|
||||
}
|
||||
|
||||
static std::string gbnf_ac_grammar(
|
||||
const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings,
|
||||
const std::function<std::string(const std::vector<uint32_t> &,
|
||||
const std::map<size_t, std::vector<uint32_t>> &,
|
||||
const std::vector<uint32_t> &,
|
||||
const std::function<std::string(size_t)> &)> & build_rule) {
|
||||
aho_corasick ac(strings);
|
||||
|
||||
auto state_name = [&](size_t s) -> std::string {
|
||||
@@ -1548,42 +1570,30 @@ static std::string gbnf_excluding_grammar(const common_grammar_builder & builder
|
||||
return prefix + "-" + num;
|
||||
};
|
||||
|
||||
auto char_class = [](const std::vector<uint32_t> & chars, bool negate) {
|
||||
std::string s = negate ? "[^" : "[";
|
||||
for (uint32_t ch : chars) {
|
||||
s += gbnf_escape_char_class(ch);
|
||||
}
|
||||
return s + "]";
|
||||
};
|
||||
|
||||
for (size_t q = 0; q < ac.num_states(); q++) {
|
||||
if (ac.is_terminal(q)) {
|
||||
continue; // match states are dropped
|
||||
continue; // match states
|
||||
}
|
||||
|
||||
std::map<size_t, std::vector<uint32_t>> buckets;
|
||||
std::vector<uint32_t> excluded;
|
||||
std::vector<uint32_t> completing; // chars that complete a delimiter
|
||||
std::vector<uint32_t> specific; // chars with an explicit transition
|
||||
for (uint32_t c : ac.alphabet) {
|
||||
size_t d = ac.next(q, c);
|
||||
if (ac.is_terminal(d)) {
|
||||
excluded.push_back(c); // completes a forbidden string -> omit
|
||||
completing.push_back(c);
|
||||
specific.push_back(c);
|
||||
} else if (d != 0) {
|
||||
buckets[d].push_back(c); // specific non-root destination
|
||||
excluded.push_back(c);
|
||||
specific.push_back(c);
|
||||
}
|
||||
}
|
||||
|
||||
std::string rhs = "|"; // every state is accepting
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
rhs += " " + char_class(chars, false) + " " + state_name(d) + " |";
|
||||
}
|
||||
rhs += " " + char_class(excluded, true) + " " + state_name(0);
|
||||
|
||||
builder.add_rule(state_name(q), rhs);
|
||||
builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name));
|
||||
}
|
||||
|
||||
// An empty delimiter makes the start state terminal. Emit an entry rule
|
||||
// that matches nothing so the returned reference stays valid.
|
||||
// that matches the empty string so the returned reference stays valid.
|
||||
if (ac.is_terminal(0)) {
|
||||
builder.add_rule(prefix, "|");
|
||||
}
|
||||
@@ -1591,6 +1601,54 @@ static std::string gbnf_excluding_grammar(const common_grammar_builder & builder
|
||||
return state_name(0);
|
||||
}
|
||||
|
||||
// GBNF grammar matching strings that contain no string in `strings` as a
|
||||
// substring. Emits the complement of an Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
//
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/24839
|
||||
static std::string gbnf_excluding_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
return gbnf_ac_grammar(builder, prefix, strings,
|
||||
[](const std::vector<uint32_t> & /*completing*/,
|
||||
const std::map<size_t, std::vector<uint32_t>> & buckets,
|
||||
const std::vector<uint32_t> & specific,
|
||||
const std::function<std::string(size_t)> & state_name) {
|
||||
// every state is accepting and completing chars get no
|
||||
// alternative, so a forbidden string can never be matched
|
||||
std::string rhs = "|";
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |";
|
||||
}
|
||||
rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0);
|
||||
return rhs;
|
||||
});
|
||||
}
|
||||
|
||||
// GBNF grammar matching everything up to and including the first occurrence of
|
||||
// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns
|
||||
// the start state rule name.
|
||||
static std::string gbnf_including_grammar(const common_grammar_builder & builder,
|
||||
const std::string & prefix,
|
||||
const std::vector<std::string> & strings) {
|
||||
return gbnf_ac_grammar(builder, prefix, strings,
|
||||
[](const std::vector<uint32_t> & completing,
|
||||
const std::map<size_t, std::vector<uint32_t>> & buckets,
|
||||
const std::vector<uint32_t> & specific,
|
||||
const std::function<std::string(size_t)> & state_name) {
|
||||
std::vector<std::string> alts;
|
||||
if (!completing.empty()) {
|
||||
alts.push_back(gbnf_char_class(completing, false)); // terminate on match
|
||||
}
|
||||
for (const auto & [d, chars] : buckets) {
|
||||
alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d));
|
||||
}
|
||||
// every other character keeps scanning from the start state
|
||||
alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0));
|
||||
return string_join(alts, " | ");
|
||||
});
|
||||
}
|
||||
|
||||
static std::set<std::string> collect_reachable_rules(
|
||||
const common_peg_arena & arena,
|
||||
const common_peg_parser_id & rule
|
||||
@@ -1628,6 +1686,7 @@ static std::set<std::string> collect_reachable_rules(
|
||||
std::is_same_v<T, common_peg_tag_parser> ||
|
||||
std::is_same_v<T, common_peg_atomic_parser> ||
|
||||
std::is_same_v<T, common_peg_gbnf_parser> ||
|
||||
std::is_same_v<T, common_peg_ac_parser> ||
|
||||
std::is_same_v<T, common_peg_schema_parser>) {
|
||||
visit(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_rule_parser>) {
|
||||
@@ -1822,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo
|
||||
return to_gbnf(p.child);
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return p.grammar;
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters);
|
||||
} else {
|
||||
static_assert(is_always_false_v<T>);
|
||||
}
|
||||
@@ -1958,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant &
|
||||
};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_gbnf_parser>) {
|
||||
return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}};
|
||||
} else if constexpr (std::is_same_v<T, common_peg_ac_parser>) {
|
||||
return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}};
|
||||
}
|
||||
}, variant);
|
||||
}
|
||||
@@ -2130,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json
|
||||
};
|
||||
}
|
||||
|
||||
if (type == "ac") {
|
||||
if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) {
|
||||
throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array");
|
||||
}
|
||||
return common_peg_ac_parser{
|
||||
j["child"].get<common_peg_parser_id>(),
|
||||
j["delimiters"].get<std::vector<std::string>>(),
|
||||
};
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unknown parser type: " + type);
|
||||
}
|
||||
|
||||
|
||||
+14
-1
@@ -275,6 +275,11 @@ struct common_peg_gbnf_parser {
|
||||
std::string grammar;
|
||||
};
|
||||
|
||||
// Parser node that delegates parsing to `child` while carrying a list of
// delimiter strings used when a GBNF grammar is generated for it.
struct common_peg_ac_parser {
|
||||
// parser that performs the actual parsing
common_peg_parser_id child;
|
||||
// delimiter strings; the builder and the deserializer both reject an empty list
std::vector<std::string> delimiters;
|
||||
};
|
||||
|
||||
// Variant holding all parser types
|
||||
using common_peg_parser_variant = std::variant<
|
||||
common_peg_epsilon_parser,
|
||||
@@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant<
|
||||
common_peg_ref_parser,
|
||||
common_peg_atomic_parser,
|
||||
common_peg_tag_parser,
|
||||
common_peg_gbnf_parser
|
||||
common_peg_gbnf_parser,
|
||||
common_peg_ac_parser
|
||||
>;
|
||||
|
||||
class common_peg_arena {
|
||||
@@ -514,6 +520,13 @@ class common_peg_parser_builder {
|
||||
// the child's grammar. Parsing delegates entirely to the child.
|
||||
common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); }
|
||||
|
||||
// Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick
|
||||
// automaton of `delimiters`, matching everything up to and including the
|
||||
// first delimiter. Parsing delegates entirely to the child, which is
|
||||
// responsible for consuming the delimiter (e.g. until(D) + literal(D)).
|
||||
common_peg_parser ac(const common_peg_parser & p, const std::vector<std::string> & delimiters);
|
||||
common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector<std::string>{delimiter}); }
|
||||
|
||||
void set_root(const common_peg_parser & p);
|
||||
|
||||
common_peg_arena build();
|
||||
|
||||
+10
-10
@@ -65,12 +65,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
if (ctx->start_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_COUNTING;
|
||||
ctx->remaining = ctx->budget;
|
||||
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
|
||||
COM_TRC("activated, budget=%d tokens\n", ctx->budget);
|
||||
|
||||
if (ctx->remaining <= 0) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
|
||||
COM_TRC("%s", "budget=0, forcing immediately\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -80,7 +80,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
{
|
||||
if (ctx->end_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: deactivated (natural end)\n");
|
||||
COM_TRC("%s", "deactivated (natural end)\n");
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
|
||||
COM_TRC("%s", "UTF-8 complete, now forcing end sequence\n");
|
||||
}
|
||||
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
|
||||
ctx->remaining--;
|
||||
@@ -104,11 +104,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
|
||||
COM_TRC("%s", "budget exhausted, forcing end sequence\n");
|
||||
} else {
|
||||
ctx->state = REASONING_BUDGET_WAITING_UTF8;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
|
||||
COM_TRC("%s", "budget exhausted, waiting for UTF-8 completion\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -118,7 +118,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->force_pos++;
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: forced sequence complete, done\n");
|
||||
COM_TRC("%s", "forced sequence complete, done\n");
|
||||
}
|
||||
break;
|
||||
case REASONING_BUDGET_DONE:
|
||||
@@ -128,12 +128,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->state = REASONING_BUDGET_COUNTING;
|
||||
ctx->remaining = ctx->budget;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);
|
||||
COM_TRC("re-activated on new start tag, budget=%d tokens\n", ctx->budget);
|
||||
|
||||
if (ctx->remaining <= 0) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
|
||||
COM_TRC("%s", "budget=0, forcing immediately\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -264,7 +264,7 @@ bool common_reasoning_budget_force(struct llama_sampler * smpl) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: forced into forcing state (manual transition)\n");
|
||||
COM_TRC("%s", "forced into forcing state (manual transition)\n");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
+69
-64
@@ -18,6 +18,13 @@
|
||||
#include <map>
|
||||
#include <cinttypes>
|
||||
|
||||
#define SPC_DBG(fmt, ...) LOG_DBG("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_TRC(fmt, ...) LOG_TRC("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_INF(fmt, ...) LOG_INF("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_WRN(fmt, ...) LOG_WRN("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_ERR(fmt, ...) LOG_ERR("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
|
||||
@@ -60,21 +67,20 @@ static bool common_speculative_are_compatible(
|
||||
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
||||
|
||||
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
|
||||
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
|
||||
SPC_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
|
||||
|
||||
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
|
||||
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
|
||||
SPC_DBG("vocab_type dft: %d\n", vocab_type_dft);
|
||||
|
||||
if (vocab_type_tgt != vocab_type_dft) {
|
||||
LOG_WRN("%s: draft model vocab type must match target model to use speculation but "
|
||||
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
|
||||
SPC_WRN("draft model vocab type must match target model to use speculation but "
|
||||
"vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
||||
(llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
|
||||
LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
__func__,
|
||||
SPC_WRN("draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
|
||||
llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
|
||||
return false;
|
||||
@@ -82,8 +88,7 @@ static bool common_speculative_are_compatible(
|
||||
|
||||
if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
||||
(llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
|
||||
LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
__func__,
|
||||
SPC_WRN("draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
|
||||
llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
|
||||
return false;
|
||||
@@ -97,8 +102,8 @@ static bool common_speculative_are_compatible(
|
||||
: n_vocab_dft - n_vocab_tgt;
|
||||
|
||||
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
||||
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
|
||||
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
||||
SPC_DBG("draft model vocab must closely match target model to use speculation but "
|
||||
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
||||
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
||||
return false;
|
||||
}
|
||||
@@ -108,8 +113,8 @@ static bool common_speculative_are_compatible(
|
||||
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
||||
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
|
||||
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
|
||||
SPC_DBG("draft model vocab must match target model to use speculation but "
|
||||
"token %d content differs - target '%s', draft '%s'\n", i,
|
||||
common_token_to_piece(vocab_tgt, i).c_str(),
|
||||
common_token_to_piece(vocab_dft, i).c_str());
|
||||
return false;
|
||||
@@ -186,9 +191,9 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
|
||||
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'draft-simple'\n");
|
||||
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f\n", this->params.n_max, this->params.n_min, this->params.p_min);
|
||||
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
|
||||
this->params.n_gpu_layers,
|
||||
ggml_type_name(this->params.cache_type_k),
|
||||
ggml_type_name(this->params.cache_type_v),
|
||||
@@ -228,16 +233,16 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
}
|
||||
|
||||
const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
|
||||
LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt);
|
||||
SPC_DBG("vocab_cmpt = %d\n", vocab_cmpt);
|
||||
|
||||
if (!vocab_cmpt) {
|
||||
LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__);
|
||||
SPC_ERR("%s", "the target and draft vocabs are not compatible\n");
|
||||
|
||||
throw std::runtime_error("draft model vocab type must match target model to use speculation");
|
||||
}
|
||||
|
||||
if (n_seq != llama_n_seq_max(ctx_dft)) {
|
||||
LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft));
|
||||
SPC_ERR("n_seq mismatch: %d != %d\n", n_seq, llama_n_seq_max(ctx_dft));
|
||||
|
||||
throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq");
|
||||
}
|
||||
@@ -257,7 +262,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
const int ret = llama_decode(ctx_dft, batch);
|
||||
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret);
|
||||
SPC_ERR("failed to decode draft batch, ret = %d\n", ret);
|
||||
|
||||
return false;
|
||||
}
|
||||
@@ -290,7 +295,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
SPC_ERR("llama_decode returned %d\n", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -314,7 +319,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
|
||||
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
@@ -354,7 +359,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
// evaluate the drafted tokens on the draft model
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -449,8 +454,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq)
|
||||
, params(params.draft)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
|
||||
SPC_TRC("%s", "adding speculative implementation 'draft-eagle3'\n");
|
||||
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
|
||||
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
@@ -493,7 +498,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
|
||||
|
||||
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
|
||||
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
|
||||
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
|
||||
llama_sampler_free(chain);
|
||||
chain = nullptr;
|
||||
}
|
||||
@@ -548,9 +553,9 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
if (pos_max < N - 2) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
|
||||
SPC_WRN("ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
|
||||
"Drafts may degrade.\n",
|
||||
__func__, (int) pos_max, N - 2);
|
||||
(int) pos_max, N - 2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -621,8 +626,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
};
|
||||
const int32_t rc = llama_encode(ctx_dft, enc_batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
|
||||
__func__, rc, (int) n_chunk, (int) i);
|
||||
SPC_ERR("llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
|
||||
rc, (int) n_chunk, (int) i);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -692,8 +697,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
if (batch.n_tokens > 0) {
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
|
||||
__func__, rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
|
||||
SPC_ERR("llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
|
||||
rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -744,7 +749,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
SPC_ERR("llama_decode returned %d\n", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -770,7 +775,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
|
||||
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
@@ -809,7 +814,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -942,9 +947,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
"MTP input row width must match the target h_nextn width");
|
||||
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'draft-mtp'\n");
|
||||
SPC_TRC("- n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
|
||||
this->params.n_gpu_layers,
|
||||
ggml_type_name(this->params.cache_type_k),
|
||||
ggml_type_name(this->params.cache_type_v),
|
||||
@@ -975,7 +980,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
|
||||
|
||||
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
|
||||
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
|
||||
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
|
||||
llama_sampler_free(chain);
|
||||
chain = nullptr;
|
||||
}
|
||||
@@ -1038,11 +1043,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
|
||||
if (pos_max < N - 1 && !is_mem_shared) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
|
||||
SPC_WRN("ctx_dft pos_max=%d < N-1=%d - "
|
||||
"process() hook may not have run on every prefill ubatch "
|
||||
"(need_embd / logits=1 on every prompt position?). "
|
||||
"Drafts may degrade.\n",
|
||||
__func__, (int) pos_max, N - 1);
|
||||
(int) pos_max, N - 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1128,8 +1133,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
|
||||
__func__, head, (int) rc, (int) batch_in.pos[0]);
|
||||
SPC_ERR("llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
|
||||
head, (int) rc, (int) batch_in.pos[0]);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
@@ -1217,7 +1222,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1239,7 +1244,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
|
||||
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
@@ -1353,8 +1358,8 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
|
||||
, params(params.ngram_simple)
|
||||
, config(config)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-simple'\n", __func__);
|
||||
LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'ngram-simple'\n");
|
||||
SPC_TRC("- size_n=%d, size_m=%d, min_hits=%d\n",
|
||||
this->params.size_n, this->params.size_m, this->params.min_hits);
|
||||
}
|
||||
|
||||
@@ -1403,8 +1408,8 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
|
||||
this->config.push_back(config);
|
||||
}
|
||||
|
||||
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str());
|
||||
LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__,
|
||||
SPC_TRC("adding speculative implementation '%s'\n", common_speculative_type_to_str(this->type).c_str());
|
||||
SPC_TRC("- size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n",
|
||||
config.size_key, config.size_value, config.key_only, config.min_hits);
|
||||
}
|
||||
|
||||
@@ -1478,15 +1483,15 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
, verbose(std::getenv("LLAMA_TRACE") != nullptr) {
|
||||
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__);
|
||||
LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'ngram-mod'\n");
|
||||
SPC_TRC("- n_match=%d, n_max=%d, n_min=%d\n",
|
||||
this->params.n_match, this->params.n_max, this->params.n_min);
|
||||
LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__,
|
||||
SPC_TRC("- mod size=%zu (%.3f MB)\n",
|
||||
mod.size(), (float)(mod.size_bytes())/1024/1024);
|
||||
|
||||
if (this->params.n_match < 16) {
|
||||
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
|
||||
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match);
|
||||
SPC_WRN("ngram_mod n_match=%d is too small - poor quality is possible, "
|
||||
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", this->params.n_match);
|
||||
}
|
||||
|
||||
sinfos.resize(n_seq);
|
||||
@@ -1510,11 +1515,11 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
sinfo.i_last = prompt.size() - n;
|
||||
|
||||
const double f = (double)mod.get_used() / (double)mod.size();
|
||||
LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f);
|
||||
SPC_TRC("ngram_mod occupancy = %zu/%zu (%.2f)\n", mod.get_used(), mod.size(), f);
|
||||
|
||||
constexpr double f_thold = 0.25;
|
||||
if (f > f_thold) {
|
||||
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
|
||||
SPC_WRN("ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", f, f_thold);
|
||||
|
||||
mod.reset();
|
||||
}
|
||||
@@ -1608,7 +1613,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
sinfo.n_low++;
|
||||
if (sinfo.n_low >= 5) {
|
||||
if (verbose) {
|
||||
LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low);
|
||||
SPC_TRC("low acceptance streak (%d) - resetting ngram_mod\n", sinfo.n_low);
|
||||
}
|
||||
|
||||
mod.reset();
|
||||
@@ -1658,8 +1663,8 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
, save_dynamic(save_dynamic)
|
||||
, save_static(save_static)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__);
|
||||
LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'ngram-cache'\n");
|
||||
SPC_TRC("- n_draft=%d, cache_static=%s, cache_dynamic=%s\n",
|
||||
n_draft,
|
||||
path_static.empty() ? "none" : path_static.c_str(),
|
||||
path_dynamic.empty() ? "none" : path_dynamic.c_str());
|
||||
@@ -1674,7 +1679,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
sinfo.ngram_cache_static = ngram_cache_static;
|
||||
}
|
||||
} catch (...) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", path_static.c_str());
|
||||
SPC_ERR("failed to open static lookup cache: %s", path_static.c_str());
|
||||
GGML_ABORT("Couldn't read static lookup cache");
|
||||
}
|
||||
}
|
||||
@@ -1687,7 +1692,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
sinfo.ngram_cache_dynamic = ngram_cache_dynamic;
|
||||
}
|
||||
} catch (...) {
|
||||
LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
|
||||
SPC_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
|
||||
GGML_ABORT("Couldn't read dynamic lookup cache");
|
||||
}
|
||||
}
|
||||
@@ -2034,7 +2039,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
}
|
||||
|
||||
if (impls.empty()) {
|
||||
LOG_WRN("%s: no implementations specified for speculative decoding\n", __func__);
|
||||
SPC_TRC("%s", "no implementations specified for speculative decoding\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -2161,13 +2166,13 @@ void common_speculative_draft(common_speculative * spec) {
|
||||
|
||||
if (dp.n_max > 0) {
|
||||
if (!result.empty() && (int) result.size() > dp.n_max) {
|
||||
LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max);
|
||||
SPC_DBG("truncating draft to %d tokens\n", dp.n_max);
|
||||
result.resize(dp.n_max);
|
||||
}
|
||||
}
|
||||
|
||||
if (!result.empty()) {
|
||||
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
|
||||
SPC_DBG("called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n",
|
||||
common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(),
|
||||
impl.get()->n_call_draft, result.size());
|
||||
|
||||
@@ -2291,7 +2296,7 @@ void common_speculative_print_stats(const common_speculative * spec) {
|
||||
str_stats = ", #mean acc len = " + oss.str() + ", #acc rate/pos = (" + tmp.str() + ")";
|
||||
}
|
||||
|
||||
LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
|
||||
SPC_TRC("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
|
||||
common_speculative_type_to_str(impl->type).c_str(),
|
||||
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
|
||||
impl->n_gen_drafts,
|
||||
|
||||
@@ -46,6 +46,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"DbrxForCausalLM": "dbrx",
|
||||
"DeciLMForCausalLM": "deci",
|
||||
"DeepseekForCausalLM": "deepseek",
|
||||
"DeepseekOCRForCausalLM": "deepseek",
|
||||
"DeepseekV2ForCausalLM": "deepseek",
|
||||
"DeepseekV3ForCausalLM": "deepseek",
|
||||
"DeepseekV32ForCausalLM": "deepseek",
|
||||
@@ -96,6 +97,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"GraniteMoeHybridForCausalLM": "granite",
|
||||
"GraniteMoeSharedForCausalLM": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"Grok1ForCausalLM": "grok",
|
||||
"GrokForCausalLM": "grok",
|
||||
"GroveMoeForCausalLM": "grovemoe",
|
||||
@@ -123,6 +125,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"LLaDAModelLM": "llada",
|
||||
"LLaMAForCausalLM": "llama",
|
||||
"Lfm25AudioTokenizer": "lfm2",
|
||||
"Lfm2BidirectionalModel": "lfm2",
|
||||
"Lfm2ForCausalLM": "lfm2",
|
||||
"Lfm2Model": "lfm2",
|
||||
"Lfm2MoeForCausalLM": "lfm2",
|
||||
@@ -133,6 +136,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"LlamaModel": "llama",
|
||||
"Eagle3DraftModel": "llama",
|
||||
"Eagle3Speculator": "llama",
|
||||
"Eagle3LlamaForCausalLM": "llama",
|
||||
"LlamaForCausalLMEagle3": "llama",
|
||||
"LlavaForConditionalGeneration": "llama",
|
||||
"LlavaStableLMEpochForCausalLM": "stablelm",
|
||||
@@ -231,6 +235,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"UMT5ForConditionalGeneration": "t5",
|
||||
"UMT5Model": "t5",
|
||||
"UltravoxModel": "ultravox",
|
||||
"UnlimitedOCRForCausalLM": "deepseek",
|
||||
"VLlama3ForCausalLM": "llama",
|
||||
"VoxtralForConditionalGeneration": "llama",
|
||||
"WavTokenizerDec": "wavtokenizer",
|
||||
@@ -261,6 +266,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
|
||||
"GlmasrModel": "ultravox",
|
||||
"Granite4VisionForConditionalGeneration": "granite",
|
||||
"GraniteSpeechForConditionalGeneration": "granite",
|
||||
"GraniteSpeechPlusForConditionalGeneration": "granite",
|
||||
"HunYuanVLForConditionalGeneration": "hunyuan",
|
||||
"Idefics3ForConditionalGeneration": "smolvlm",
|
||||
"InternVisionModel": "internvl",
|
||||
@@ -296,6 +302,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = {
|
||||
"StepVLForConditionalGeneration": "step3",
|
||||
"Step3p7ForConditionalGeneration": "step3",
|
||||
"UltravoxModel": "ultravox",
|
||||
"UnlimitedOCRForCausalLM": "deepseek",
|
||||
"VoxtralForConditionalGeneration": "ultravox",
|
||||
"YoutuVLForConditionalGeneration": "youtuvl",
|
||||
}
|
||||
|
||||
+10
-2
@@ -14,7 +14,7 @@ from .base import MmprojModel, ModelBase, TextModel, gguf, logger
|
||||
from .qwen import QwenModel
|
||||
|
||||
|
||||
@ModelBase.register("DeepseekOCRForCausalLM")
|
||||
@ModelBase.register("DeepseekOCRForCausalLM", "UnlimitedOCRForCausalLM")
|
||||
class DeepseekOCRVisionModel(MmprojModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -205,6 +205,8 @@ class DeepseekModel(TextModel):
|
||||
@ModelBase.register(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekOCRForCausalLM",
|
||||
"UnlimitedOCRForCausalLM",
|
||||
"KimiVLForConditionalGeneration",
|
||||
"KimiK25ForConditionalGeneration",
|
||||
"YoutuForCausalLM",
|
||||
@@ -224,7 +226,7 @@ class DeepseekV2Model(TextModel):
|
||||
self.origin_hf_arch = hparams.get('architectures', [None])[0]
|
||||
|
||||
# special handling for Deepseek OCR
|
||||
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM"):
|
||||
if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM", "UnlimitedOCRForCausalLM"):
|
||||
self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR
|
||||
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
self.gguf_writer.add_architecture()
|
||||
@@ -350,6 +352,12 @@ class DeepseekV2Model(TextModel):
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
# Unlimited-OCR sliding window; written for metadata, the decoder ignores it (full MHA)
|
||||
if is_ocr:
|
||||
sliding_window = hparams.get("sliding_window_size") or hparams.get("sliding_window")
|
||||
if sliding_window:
|
||||
self.gguf_writer.add_sliding_window(sliding_window)
|
||||
|
||||
if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None:
|
||||
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
||||
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
|
||||
|
||||
@@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("GraniteSpeechPlusForConditionalGeneration")
class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel):
    """GraniteSpeechPlus mmproj converter.

    Behaves like the GraniteSpeech converter it derives from, and additionally
    records the audio feature layers listed under ``cat_hidden_layers`` in the
    audio hparams, checking that the projector input width is consistent with
    that list.
    """
    has_vision_encoder = False
    has_audio_encoder = True

    def set_gguf_parameters(self):
        """Write the inherited GGUF parameters, then the audio feature layers.

        Raises:
            ValueError: if ``projector_config.encoder_hidden_size`` differs from
                ``hidden_dim * (len(feature_layers) + 1)``.
        """
        assert self.hparams_audio is not None
        super().set_gguf_parameters()

        # Nothing more to write when the encoder config lists no feature layers
        feature_layers = self.hparams_audio.get("cat_hidden_layers")
        if not feature_layers:
            return

        self.gguf_writer.add_audio_feature_layers(feature_layers)
        logger.info(f"gguf: audio feature_layers = {feature_layers}")

        # The projector must accept one hidden_dim-wide slice per listed layer, plus one more
        hidden_dim = self.hparams_audio["hidden_dim"]
        expected_dim = hidden_dim * (len(feature_layers) + 1)
        projector_dim = self.global_config["projector_config"]["encoder_hidden_size"]

        if projector_dim == expected_dim:
            return

        raise ValueError(
            f"Projector encoder_hidden_size ({projector_dim}) does not match "
            f"expected concatenated dimension ({expected_dim}). "
            f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}"
        )
|
||||
|
||||
|
||||
@ModelBase.register("Granite4VisionForConditionalGeneration")
|
||||
class Granite4VisionMmprojModel(MmprojModel):
|
||||
has_vision_encoder = True
|
||||
|
||||
+10
-3
@@ -64,11 +64,17 @@ class LFM2Model(TextModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Lfm2Model")
|
||||
@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel")
|
||||
class LFM2ColBertModel(LFM2Model):
|
||||
model_arch = gguf.MODEL_ARCH.LFM2
|
||||
dense_tensor_name = "dense_2"
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if self.hf_arch == "Lfm2BidirectionalModel":
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
self._try_set_pooling_type()
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if not name.startswith(self.dense_tensor_name):
|
||||
name = "model." + name
|
||||
@@ -76,10 +82,11 @@ class LFM2ColBertModel(LFM2Model):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# dense tensor is stored in a separate safetensors file
|
||||
# optional dense tensor is stored in a separate safetensors file
|
||||
from safetensors.torch import load_file
|
||||
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
|
||||
assert tensors_file.is_file()
|
||||
if not tensors_file.is_file():
|
||||
return
|
||||
tensor = load_file(tensors_file)["linear.weight"]
|
||||
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
|
||||
yield f"{self.dense_tensor_name}.weight", tensor.clone()
|
||||
|
||||
@@ -23,6 +23,7 @@ from .base import ModelBase, TextModel, gguf, logger
|
||||
"LlavaForConditionalGeneration",
|
||||
"VoxtralForConditionalGeneration",
|
||||
"LlamaForCausalLMEagle3",
|
||||
"Eagle3LlamaForCausalLM",
|
||||
"Eagle3Speculator",
|
||||
"Eagle3DraftModel",
|
||||
"IQuestCoderForCausalLM",
|
||||
|
||||
+3
-4
@@ -114,7 +114,8 @@ class Mamba2Model(TextModel):
|
||||
hparams["text_config"] = hparams["llm_config"]
|
||||
super().__init__(dir_model, *args, hparams=hparams, **kwargs)
|
||||
self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
|
||||
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
|
||||
self.expand = self.find_hparam(["mamba_expand", "expand"], optional=True) or 2
|
||||
self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or self.expand * self.d_model
|
||||
self.n_group = self.find_hparam(["n_groups"], optional=True) or 1
|
||||
|
||||
def set_vocab(self):
|
||||
@@ -144,11 +145,9 @@ class Mamba2Model(TextModel):
|
||||
|
||||
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
|
||||
|
||||
# Fail early for models which don't have a block expansion factor of 2
|
||||
# TODO: does this really matter?
|
||||
# skip the assertion for FalconH1 Model
|
||||
if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
|
||||
assert self.d_inner == 2 * self.d_model
|
||||
assert self.d_inner == self.expand * self.d_model
|
||||
assert self.d_inner % head_dim == 0
|
||||
|
||||
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
|
||||
|
||||
+1
-1
@@ -29,7 +29,7 @@ With Termux, you can install and run `llama.cpp` as if the environment were Linu
|
||||
|
||||
```
|
||||
$ apt update && apt upgrade -y
|
||||
$ apt install git cmake
|
||||
$ apt install git cmake libandroid-spawn
|
||||
```
|
||||
|
||||
Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake.
|
||||
|
||||
@@ -237,8 +237,8 @@ chmod +x ubuntu-llamacpp-ov-install.sh
|
||||
# ============================================
|
||||
set -euo pipefail
|
||||
|
||||
OPENVINO_VERSION_MAJOR="2026.2"
|
||||
OPENVINO_VERSION_FULL="2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR="2026.2.1"
|
||||
OPENVINO_VERSION_FULL="2026.2.1.21919.ede283a88e3"
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
OPENVINO_INSTALL_DIR="/opt/intel/openvino_${OPENVINO_VERSION_MAJOR}"
|
||||
@@ -334,7 +334,7 @@ echo " ./build/ReleaseOV/bin/llama-cli -m model.gguf"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The script pins OpenVINO `2026.2` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release.
|
||||
> The script pins OpenVINO `2026.2.1` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release.
|
||||
|
||||
</details>
|
||||
|
||||
@@ -364,8 +364,8 @@ REM ============================================
|
||||
REM llama.cpp OpenVINO Build Script (Ninja)
|
||||
REM ============================================
|
||||
|
||||
set "OPENVINO_VERSION_MAJOR=2026.2"
|
||||
set "OPENVINO_VERSION_FULL=2026.2.0.21903.52ddc073857"
|
||||
set "OPENVINO_VERSION_MAJOR=2026.2.1"
|
||||
set "OPENVINO_VERSION_FULL=2026.2.1.21919.ede283a88e3"
|
||||
|
||||
set "SCRIPT_DIR=%~dp0"
|
||||
set "VCPKG_DIR=C:\vcpkg"
|
||||
@@ -547,7 +547,7 @@ endlocal
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The script pins OpenVINO `2026.2` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release. From any new shell, source the matching `setupvars` script via the junction — `call "C:\Intel\openvino\setupvars.bat"` from `cmd`, or `& "C:\Intel\openvino\setupvars.ps1"` from PowerShell. If `winget` cannot register Visual Studio Build Tools on first run, install them once manually and re-run the script from an elevated **Developer Command Prompt for VS 2022**.
|
||||
> The script pins OpenVINO `2026.2.1` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release. From any new shell, source the matching `setupvars` script via the junction — `call "C:\Intel\openvino\setupvars.bat"` from `cmd`, or `& "C:\Intel\openvino\setupvars.ps1"` from PowerShell. If `winget` cannot register Visual Studio Build Tools on first run, install them once manually and re-run the script from an elevated **Developer Command Prompt for VS 2022**.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@@ -413,6 +413,15 @@ In two device selection modes, the default SYCL backend is level_zero, you can c
|
||||
|------------------|----------------------------------------|
|
||||
| Single device | --split-mode none --main-gpu DEVICE_ID |
|
||||
| Multiple devices | --split-mode layer (default) |
|
||||
| Multiple devices | --split-mode tensor (tensor parallelism) |
|
||||
|
||||
`--split-mode tensor` (tensor parallelism) shards each layer across the selected
|
||||
GPUs. It requires flash attention, which is auto-enabled when `--flash-attn` is
|
||||
left at its default `auto`, so `--split-mode tensor` works out of the box.
|
||||
Passing `--flash-attn off` together with `--split-mode tensor` is rejected at
|
||||
context creation. The default `f16` KV cache is recommended. Tensor parallelism
|
||||
is currently optimized for 2 GPUs; other device counts fall back to a generic
|
||||
all-reduce.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -715,6 +724,15 @@ In two device selection modes, the default SYCL backend is level_zero, you can c
|
||||
|------------------|----------------------------------------|
|
||||
| Single device | --split-mode none --main-gpu DEVICE_ID |
|
||||
| Multiple devices | --split-mode layer (default) |
|
||||
| Multiple devices | --split-mode tensor (tensor parallelism) |
|
||||
|
||||
`--split-mode tensor` (tensor parallelism) shards each layer across the selected
|
||||
GPUs. It requires flash attention, which is auto-enabled when `--flash-attn` is
|
||||
left at its default `auto`, so `--split-mode tensor` works out of the box.
|
||||
Passing `--flash-attn off` together with `--split-mode tensor` is rejected at
|
||||
context creation. The default `f16` KV cache is recommended. Tensor parallelism
|
||||
is currently optimized for 2 GPUs; other device counts fall back to a generic
|
||||
all-reduce.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "ON",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
|
||||
"LLAMA_OPENSSL": "OFF"
|
||||
}
|
||||
},
|
||||
@@ -47,7 +46,6 @@
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "ON",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
|
||||
"LLAMA_OPENSSL": "OFF"
|
||||
}
|
||||
},
|
||||
@@ -73,7 +71,6 @@
|
||||
"GGML_LLAMAFILE": "OFF",
|
||||
"GGML_OPENCL": "OFF",
|
||||
"GGML_HEXAGON": "ON",
|
||||
"GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128",
|
||||
"LLAMA_OPENSSL": "OFF"
|
||||
}
|
||||
},
|
||||
|
||||
+41
-1
@@ -13,6 +13,45 @@ The `llama-server` application supports several implementations of speculative d
|
||||
A much smaller model (called the _draft model_) generates drafts.
|
||||
Using a draft model is the most common approach to speculative decoding.
|
||||
|
||||
### EAGLE-3 (`draft-eagle3`)
|
||||
|
||||
EAGLE-3 uses a small draft model that reads the target model's hidden states to predict the next tokens, so it
|
||||
reaches higher acceptance than a standalone draft model of the same size. The draft is a one-layer transformer
|
||||
trained for a specific target model; it shares the target model's tokenizer and, optionally, uses a reduced draft
|
||||
vocabulary with its own `lm_head`, whose token ids are mapped back to the target vocabulary using a `d2t` table.
|
||||
|
||||
Convert the EAGLE-3 checkpoint with `--target-model-dir` so it inherits the target's tokenizer and the layer
|
||||
indices to read. Both the SpecForge `LlamaForCausalLMEagle3` and the vLLM/AngelSlim `Eagle3LlamaForCausalLM`
|
||||
checkpoint formats are supported (for example [`AngelSlim/Qwen3-4B_eagle3`](https://huggingface.co/AngelSlim/Qwen3-4B_eagle3)
|
||||
for `Qwen/Qwen3-4B`):
|
||||
|
||||
```bash
|
||||
python convert_hf_to_gguf.py AngelSlim/Qwen3-4B_eagle3 \
|
||||
--target-model-dir Qwen/Qwen3-4B --outtype bf16 --outfile Qwen3-4B-eagle3.gguf
|
||||
|
||||
llama-server -m Qwen3-4B.gguf -md Qwen3-4B-eagle3.gguf --spec-type draft-eagle3
|
||||
```
|
||||
|
||||
Supported EAGLE-3 draft models include:
|
||||
|
||||
- [yuhuili/EAGLE3-LLaMA3.1-Instruct-8B](https://huggingface.co/yuhuili/EAGLE3-LLaMA3.1-Instruct-8B)
|
||||
- [yuhuili/EAGLE3-LLaMA3.3-Instruct-70B](https://huggingface.co/yuhuili/EAGLE3-LLaMA3.3-Instruct-70B)
|
||||
- [RedHatAI/gemma-4-31B-it-speculator.eagle3](https://huggingface.co/RedHatAI/gemma-4-31B-it-speculator.eagle3)
|
||||
- [RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3](https://huggingface.co/RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3)
|
||||
- [Tengyunw/qwen3_8b_eagle3](https://huggingface.co/Tengyunw/qwen3_8b_eagle3)
|
||||
- [Tengyunw/qwen3_30b_moe_eagle3](https://huggingface.co/Tengyunw/qwen3_30b_moe_eagle3)
|
||||
- [AngelSlim/Qwen3-1.7B_eagle3](https://huggingface.co/AngelSlim/Qwen3-1.7B_eagle3)
|
||||
- [AngelSlim/Qwen3-4B_eagle3](https://huggingface.co/AngelSlim/Qwen3-4B_eagle3)
|
||||
- [AngelSlim/Qwen3-8B_eagle3](https://huggingface.co/AngelSlim/Qwen3-8B_eagle3)
|
||||
- [AngelSlim/Qwen3-14B_eagle3](https://huggingface.co/AngelSlim/Qwen3-14B_eagle3)
|
||||
- [AngelSlim/Qwen3-32B_eagle3](https://huggingface.co/AngelSlim/Qwen3-32B_eagle3)
|
||||
- [AngelSlim/Qwen3-a3B_eagle3](https://huggingface.co/AngelSlim/Qwen3-a3B_eagle3)
|
||||
- [RedHatAI/gpt-oss-20b-speculator.eagle3](https://huggingface.co/RedHatAI/gpt-oss-20b-speculator.eagle3)
|
||||
- [lmsys/EAGLE3-gpt-oss-120b-bf16](https://huggingface.co/lmsys/EAGLE3-gpt-oss-120b-bf16)
|
||||
- [nvidia/gpt-oss-120b-Eagle3-long-context](https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-long-context)
|
||||
|
||||
For the full and up-to-date list of supported models, see [#18039](https://github.com/ggml-org/llama.cpp/issues/18039).
|
||||
|
||||
### n-gram Cache (`ngram-cache`)
|
||||
|
||||
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
|
||||
@@ -108,7 +147,7 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
|
||||
### General Speculative Parameters
|
||||
|
||||
```
|
||||
--spec-type [none|draft-simple|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
--spec-type [none|draft-simple|draft-eagle3|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
comma-separated list of types of speculative decoding to use
|
||||
(default: none)
|
||||
(env: LLAMA_ARG_SPEC_TYPE)
|
||||
@@ -247,6 +286,7 @@ Specifies a comma-separated list of speculative decoding types to use.
|
||||
|------|-------------|
|
||||
| `none` | No speculative decoding (default) |
|
||||
| `draft-simple` | Use a simple draft model for speculation |
|
||||
| `draft-eagle3` | Use an EAGLE-3 draft model that reads the target's hidden states |
|
||||
| `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model |
|
||||
| `ngram-cache` | Use n-gram cache lookup |
|
||||
| `ngram-simple` | Use simple n-gram pattern matching |
|
||||
|
||||
+1
-2
@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 15)
|
||||
set(GGML_VERSION_PATCH 2)
|
||||
set(GGML_VERSION_PATCH 3)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
@@ -266,7 +266,6 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING
|
||||
"ggml: OpenCL API version to target")
|
||||
|
||||
option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF)
|
||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)")
|
||||
|
||||
# toolchain for vulkan-shaders-gen
|
||||
set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen")
|
||||
|
||||
@@ -27,6 +27,14 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int de
|
||||
// split tensor buffer that splits matrices by rows across multiple devices
|
||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
|
||||
|
||||
// Tensor parallelism (--split-mode tensor): comm_init/free/allreduce_tensor
|
||||
// trio queried by the meta-backend via ggml_backend_reg_get_proc_address.
|
||||
// See typedefs in ggml/include/ggml-backend.h. Mirrors the CUDA backend's
|
||||
// pattern (ggml_backend_cuda_comm_*).
|
||||
GGML_BACKEND_API void * ggml_backend_sycl_comm_init(ggml_backend_t * backends, size_t n_backends);
|
||||
GGML_BACKEND_API void ggml_backend_sycl_comm_free(void * comm_ctx);
|
||||
GGML_BACKEND_API bool ggml_backend_sycl_comm_allreduce_tensor(void * comm_ctx, struct ggml_tensor ** tensors);
|
||||
|
||||
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
||||
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
|
||||
|
||||
|
||||
@@ -1551,6 +1551,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
int split_backend_id = split->backend_id;
|
||||
ggml_backend_t split_backend = sched->backends[split_backend_id];
|
||||
|
||||
ggml_backend_synchronize(split_backend);
|
||||
|
||||
// copy the input tensors to the split backend
|
||||
for (int input_id = 0; input_id < split->n_inputs; input_id++) {
|
||||
ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
|
||||
@@ -1561,15 +1563,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
// inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done
|
||||
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
|
||||
ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]);
|
||||
} else {
|
||||
} else if (!split_backend->iface.cpy_tensor_async) {
|
||||
ggml_backend_synchronize(split_backend);
|
||||
}
|
||||
ggml_backend_tensor_copy(input, input_cpy);
|
||||
ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy);
|
||||
} else {
|
||||
// wait for the split backend to finish using the input before overwriting it
|
||||
if (sched->events[split_backend_id][sched->cur_copy] != NULL) {
|
||||
ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]);
|
||||
} else {
|
||||
} else if (!split_backend->iface.cpy_tensor_async) {
|
||||
ggml_backend_synchronize(split_backend);
|
||||
}
|
||||
|
||||
@@ -1674,6 +1676,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
|
||||
}
|
||||
}
|
||||
|
||||
ggml_backend_synchronize(split_backend);
|
||||
|
||||
if (!sched->callback_eval) {
|
||||
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
|
||||
if (ec != GGML_STATUS_SUCCESS) {
|
||||
|
||||
+50
-23
@@ -3688,8 +3688,6 @@ static void ggml_compute_forward_norm_f32(
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -3703,25 +3701,49 @@ static void ggml_compute_forward_norm_f32(
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, x);
|
||||
float mean = sum/ne00;
|
||||
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
|
||||
const float * xf = (const float *) x;
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
float variance = 0;
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, xf);
|
||||
float mean = sum/ne00;
|
||||
|
||||
float * yf = (float *) y;
|
||||
float variance = 0;
|
||||
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
mean = -mean;
|
||||
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
|
||||
vDSP_measqv(y, 1, &variance, ne00);
|
||||
mean = -mean;
|
||||
vDSP_vsadd(xf, 1, &mean, yf, 1, ne00);
|
||||
vDSP_measqv(yf, 1, &variance, ne00);
|
||||
#else
|
||||
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
|
||||
variance = ggml_vec_cvar_f32(ne00, yf, xf, mean);
|
||||
#endif //GGML_USE_ACCELERATE
|
||||
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
ggml_vec_scale_f32(ne00, yf, scale);
|
||||
} else {
|
||||
float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += *(const float *) (x + i00*nb00);
|
||||
}
|
||||
const float mean = sum/ne00;
|
||||
|
||||
float variance = 0.0f;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const float v = *(const float *) (x + i00*nb00) - mean;
|
||||
*(float *) (y + i00*nb0) = v;
|
||||
variance += v * v;
|
||||
}
|
||||
variance /= ne00;
|
||||
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
*(float *) (y + i00*nb0) *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4142,8 +4164,6 @@ static void ggml_compute_forward_l2_norm_f32(
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -4158,20 +4178,27 @@ static void ggml_compute_forward_l2_norm_f32(
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)(x[i00] * x[i00]);
|
||||
const float xi = *(const float *) (x + i00*nb00);
|
||||
sum += (ggml_float)(xi * xi);
|
||||
}
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
|
||||
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
ggml_vec_scale_f32(ne00, (float *) y, scale);
|
||||
} else {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const float xi = *(const float *) (x + i00*nb00);
|
||||
*(float *) (y + i00*nb0) = xi * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,12 +75,12 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
|
||||
ay1 = GGML_F32_VEC_LOAD(y + i);
|
||||
sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1);
|
||||
}
|
||||
// maximum number of leftover elements will be less than ggml_f32_epr. Apply predicated svmad on available elements only
|
||||
// maximum number of leftover elements will be less than ggml_f32_epr. Apply predicated svmla on available elements only
|
||||
if (np2 < n) {
|
||||
svbool_t pg = svwhilelt_b32(np2, n);
|
||||
ax1 = svld1_f32(pg, x + np2);
|
||||
ay1 = svld1_f32(pg, y + np2);
|
||||
sum1 = svmad_f32_m(pg, ax1, ay1, sum1);
|
||||
sum1 = svmla_f32_m(pg, sum1, ax1, ay1);
|
||||
}
|
||||
// reduce sum1,sum2 to sum1
|
||||
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
|
||||
|
||||
@@ -34,26 +34,26 @@ template <float (*bin_op)(const float, const float),
|
||||
static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
const src1_t * src1,
|
||||
dst_t * dst,
|
||||
const int ne0,
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const uint32_t ne0,
|
||||
const uint32_t ne1,
|
||||
const uint32_t ne2,
|
||||
const uint3 ne3,
|
||||
const uint3 ne10,
|
||||
const uint3 ne11,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
/*const uint32_t s0,*/
|
||||
const uint32_t s1,
|
||||
const uint32_t s2,
|
||||
const uint32_t s3,
|
||||
const uint32_t s00,
|
||||
const uint32_t s01,
|
||||
const uint32_t s02,
|
||||
const uint32_t s03,
|
||||
const uint32_t s10,
|
||||
const uint32_t s11,
|
||||
const uint32_t s12,
|
||||
const uint32_t s13,
|
||||
src1_ptrs... src1s) {
|
||||
ggml_cuda_pdl_lc();
|
||||
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
@@ -61,7 +61,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
|
||||
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
|
||||
|
||||
if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
|
||||
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -69,25 +69,32 @@ static __global__ void k_bin_bcast(const src0_t * src0,
|
||||
const uint32_t i12 = fastmodulo(i2, ne12);
|
||||
const uint32_t i13 = fastmodulo(i3, ne13);
|
||||
|
||||
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01;
|
||||
const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11;
|
||||
const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1;
|
||||
|
||||
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
const uint32_t s0 = blockDim.x * gridDim.x;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
|
||||
for (uint32_t i0 = i0s; i0 < ne0; i0 += s0) {
|
||||
const uint32_t i10 = fastmodulo(i0, ne10);
|
||||
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[size_t(i0)*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
|
||||
// protect i0 from overflow
|
||||
if (ne0 - i0 <= s0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,19 +117,19 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
||||
const uint3 ne12,
|
||||
const uint3 ne13,
|
||||
/*const int s0,*/
|
||||
const int s1,
|
||||
const int s2,
|
||||
const int s3,
|
||||
const int s00,
|
||||
const int s01,
|
||||
const int s02,
|
||||
const int s03,
|
||||
const int s10,
|
||||
const int s11,
|
||||
const int s12,
|
||||
const int s13,
|
||||
const uint32_t s1,
|
||||
const uint32_t s2,
|
||||
const uint32_t s3,
|
||||
const uint32_t s00,
|
||||
const uint32_t s01,
|
||||
const uint32_t s02,
|
||||
const uint32_t s03,
|
||||
const uint32_t s10,
|
||||
const uint32_t s11,
|
||||
const uint32_t s12,
|
||||
const uint32_t s13,
|
||||
src1_ptrs... src1s) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
const uint32_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
const uint32_t i3 = fastdiv(i, prod_012);
|
||||
const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
|
||||
@@ -133,25 +140,25 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
|
||||
return;
|
||||
}
|
||||
|
||||
const int i11 = fastmodulo(i1, ne11);
|
||||
const int i12 = fastmodulo(i2, ne12);
|
||||
const int i13 = fastmodulo(i3, ne13);
|
||||
const uint32_t i11 = fastmodulo(i1, ne11);
|
||||
const uint32_t i12 = fastmodulo(i2, ne12);
|
||||
const uint32_t i13 = fastmodulo(i3, ne13);
|
||||
|
||||
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
|
||||
const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01;
|
||||
const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11;
|
||||
const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1;
|
||||
|
||||
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
const int i10 = fastmodulo(i0, ne10);
|
||||
const uint32_t i10 = fastmodulo(i0, ne10);
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
|
||||
float result = src0_row ? (float) src0_row[size_t(i0)*s00] : 0.0f;
|
||||
if constexpr (sizeof...(src1_ptrs) > 0) {
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
|
||||
result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10])));
|
||||
} else {
|
||||
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
|
||||
result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]);
|
||||
}
|
||||
|
||||
dst_row[i0] = (dst_t) result;
|
||||
@@ -248,6 +255,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
size_t s02 = nb02 / sizeof(src0_t);
|
||||
size_t s03 = nb03 / sizeof(src0_t);
|
||||
|
||||
GGML_ASSERT(ne0 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(ne1 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(ne2 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(ne3 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
//GGML_ASSERT(s0 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s1 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s2 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s3 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
GGML_ASSERT(s00 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s01 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s02 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s03 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
GGML_ASSERT(s10 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s11 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s12 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(s13 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
GGML_ASSERT(cne1[0] <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(cne1[1] <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(cne1[2] <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(cne1[3] <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
|
||||
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
|
||||
@@ -263,6 +295,8 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
|
||||
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
|
||||
|
||||
GGML_ASSERT(ne2 * ne3 <= std::numeric_limits<unsigned int>::max());
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
|
||||
@@ -281,7 +315,13 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
|
||||
|
||||
if (block_nums.z > 65535 || block_nums.y > 65535) {
|
||||
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
|
||||
int64_t block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
|
||||
|
||||
GGML_ASSERT(block_num <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(block_num * block_size <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(ne0 * ne1 <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(ne0 * ne1 * ne2 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
|
||||
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
|
||||
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
|
||||
@@ -298,6 +338,10 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
|
||||
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(int64_t(block_nums.x) * block_dims.x <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(int64_t(block_nums.y) * block_dims.y <= std::numeric_limits<uint32_t>::max());
|
||||
GGML_ASSERT(int64_t(block_nums.z) * block_dims.z <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
|
||||
{
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
|
||||
|
||||
+80
-29
@@ -53,10 +53,10 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
|
||||
const int64_t nmat = ne / (ne00 * ne01);
|
||||
const int64_t n = ne00 * ne01;
|
||||
|
||||
const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
|
||||
const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
|
||||
const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
|
||||
const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
|
||||
const int64_t x = (int64_t) blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
|
||||
const int64_t y = (int64_t) blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
|
||||
const int64_t tx = (int64_t) blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
|
||||
const int64_t ty = (int64_t) blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
|
||||
|
||||
__shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
|
||||
int cur_tile_buf = 0;
|
||||
@@ -197,7 +197,7 @@ static void ggml_cpy_scalar_contiguous_cuda(
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar_contiguous<src_t, dst_t>, launch_params, cx, cdst, ne);
|
||||
}
|
||||
@@ -208,6 +208,14 @@ static void ggml_cpy_scalar_cuda(
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
|
||||
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
||||
|
||||
const auto launch_scalar_generic = [&]() {
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar<cpy_1_scalar<src_t, dst_t>>, launch_params,
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
};
|
||||
|
||||
if (transposed) {
|
||||
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
|
||||
int64_t ne00n, ne01n, ne02n;
|
||||
@@ -224,20 +232,18 @@ static void ggml_cpy_scalar_cuda(
|
||||
int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
|
||||
int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
|
||||
int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
|
||||
GGML_ASSERT(grid_x < UINT_MAX);
|
||||
GGML_ASSERT(grid_y < USHRT_MAX);
|
||||
GGML_ASSERT(grid_z < USHRT_MAX);
|
||||
dim3 dimGrid(grid_x, grid_y, grid_z);
|
||||
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar_transpose<dst_t>, launch_params,
|
||||
cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
GGML_ASSERT(grid_x <= INT_MAX);
|
||||
if (grid_y > USHRT_MAX || grid_z > USHRT_MAX) {
|
||||
launch_scalar_generic();
|
||||
} else {
|
||||
dim3 dimGrid(grid_x, grid_y, grid_z);
|
||||
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar_transpose<dst_t>, launch_params,
|
||||
cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
} else {
|
||||
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream);
|
||||
ggml_cuda_kernel_launch(cpy_scalar<cpy_1_scalar<src_t, dst_t>>, launch_params,
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
launch_scalar_generic();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,7 +254,7 @@ static void ggml_cpy_f32_q8_0_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK8_0 == 0);
|
||||
const int64_t num_blocks = ne / QK8_0;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -259,7 +265,7 @@ static void ggml_cpy_q8_0_f32_cuda(
|
||||
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
|
||||
|
||||
const int64_t num_blocks = ne;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -271,7 +277,7 @@ static void ggml_cpy_f32_q4_0_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK4_0 == 0);
|
||||
const int64_t num_blocks = ne / QK4_0;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -284,7 +290,7 @@ static void ggml_cpy_q4_0_f32_cuda(
|
||||
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
||||
cudaStream_t stream) {
|
||||
const int64_t num_blocks = ne;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
@@ -297,7 +303,7 @@ static void ggml_cpy_f32_q4_1_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK4_1 == 0);
|
||||
const int64_t num_blocks = ne / QK4_1;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -310,7 +316,7 @@ static void ggml_cpy_q4_1_f32_cuda(
|
||||
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
||||
cudaStream_t stream) {
|
||||
const int64_t num_blocks = ne;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
@@ -323,7 +329,7 @@ static void ggml_cpy_f32_q5_0_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK5_0 == 0);
|
||||
const int64_t num_blocks = ne / QK5_0;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -336,7 +342,7 @@ static void ggml_cpy_q5_0_f32_cuda(
|
||||
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
||||
cudaStream_t stream) {
|
||||
const int64_t num_blocks = ne;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
@@ -349,7 +355,7 @@ static void ggml_cpy_f32_q5_1_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK5_1 == 0);
|
||||
const int64_t num_blocks = ne / QK5_1;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
@@ -362,7 +368,7 @@ static void ggml_cpy_q5_1_f32_cuda(
|
||||
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
|
||||
cudaStream_t stream) {
|
||||
const int64_t num_blocks = ne;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
|
||||
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
|
||||
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
@@ -375,11 +381,51 @@ static void ggml_cpy_f32_iq4_nl_cuda(
|
||||
|
||||
GGML_ASSERT(ne % QK4_NL == 0);
|
||||
const int64_t num_blocks = ne / QK4_NL;
|
||||
GGML_ASSERT(num_blocks < UINT_MAX);
|
||||
GGML_ASSERT(num_blocks <= INT_MAX);
|
||||
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
// check if a same-type copy reduces to a 2D strided copy (height rows of width
|
||||
// contiguous bytes), so it can use cudaMemcpy2DAsync instead of the scalar kernel
|
||||
//
// width/height/spitch/dpitch are outputs. They are assigned only after the early-return
// checks below have passed, and the function can still return false after assigning them
// (when a pitch is smaller than the row width), so they are meaningful only when the
// return value is true.
static bool ggml_cuda_cpy_as_memcpy_2d(const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
size_t & width, size_t & height, size_t & spitch, size_t & dpitch) {
|
||||
// require matching shape: a reshaped copy maps elements by flat order, which the
|
||||
// prefix walk below does not handle
|
||||
if (src0->type != src1->type || !ggml_are_same_shape(src0, src1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// grow the contiguous prefix block shared by both tensors
|
||||
// NOTE(review): the walk scales block_nb by ne[d]; presumably this relies on
// ggml_element_size() being the byte size of one ne[0] unit - TODO confirm for
// block-quantized types
size_t block_nb = ggml_element_size(src0);
|
||||
int d = 0;
|
||||
for (; d < GGML_MAX_DIMS; ++d) {
|
||||
if (src0->nb[d] != block_nb || src1->nb[d] != block_nb) {
|
||||
break;
|
||||
}
|
||||
// dim d is densely packed in both tensors: fold its extent into the block size
block_nb *= src0->ne[d];
|
||||
}
|
||||
|
||||
// d == 0: nothing contiguous; d == GGML_MAX_DIMS: fully contiguous (handled by memcpy)
|
||||
if (d == 0 || d == GGML_MAX_DIMS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// dim d carries the rows; everything above it must be a single element
|
||||
for (int i = d + 1; i < GGML_MAX_DIMS; ++i) {
|
||||
// checking src0 alone is enough: src1 was verified to have the same shape above
if (src0->ne[i] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
width = block_nb;
|
||||
height = src0->ne[d];
|
||||
spitch = src0->nb[d];
|
||||
dpitch = src1->nb[d];
|
||||
|
||||
// reject layouts where either row stride is smaller than the row itself
return spitch >= width && dpitch >= width;
|
||||
}
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||
@@ -415,6 +461,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
|
||||
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
|
||||
|
||||
size_t mc_width = 0, mc_height = 0, mc_spitch = 0, mc_dpitch = 0;
|
||||
|
||||
if (src0->type == src1->type && contiguous_srcs) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
|
||||
@@ -425,6 +473,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
{
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
} else if (ggml_cuda_cpy_as_memcpy_2d(src0, src1, mc_width, mc_height, mc_spitch, mc_dpitch)) {
|
||||
CUDA_CHECK(cudaMemcpy2DAsync(src1_ddc, mc_dpitch, src0_ddc, mc_spitch,
|
||||
mc_width, mc_height, cudaMemcpyDeviceToDevice, main_stream));
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_scalar_cuda<float, float, true>
|
||||
|
||||
@@ -3192,11 +3192,24 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
|
||||
ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
|
||||
ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
|
||||
|
||||
if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
|
||||
// Enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA
|
||||
// Excluding this path for HIP and MUSA as a precaution.
|
||||
// According to the summary in https://github.com/ggml-org/llama.cpp/pull/20793#issuecomment-4275794315, this change is not beneficial for hip anyways.
|
||||
// Additionally, there is a lot of anecdotal evidence that hip/musa stream behavior might not always 1:1 match CUDA behavior.
|
||||
// e.g. https://github.com/ROCm/rocm-systems/issues/5109
|
||||
// It thus makes sense to exclude this path for HIP and MUSA. This PR was not aimed at these backends; the majority of testing happened on CUDA.
|
||||
// This can be revisited in the future if enabling copy_from_host benefits hip/MUSA, and if the PR author can extensively test on these backends.
|
||||
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA)
|
||||
const bool copy_from_host = false;
|
||||
#else
|
||||
const bool copy_from_host = ggml_backend_buffer_is_host(buf_src) && ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU;
|
||||
#endif
|
||||
|
||||
if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) {
|
||||
if (!(copy_from_host || ggml_backend_buffer_is_cuda(buf_src)) || !ggml_backend_buffer_is_cuda(buf_dst)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -3207,14 +3220,17 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
|
||||
ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context;
|
||||
ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context;
|
||||
|
||||
if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
|
||||
if ((copy_from_host && cuda_ctx_dst->device != buf_ctx_dst->device) ||
|
||||
!copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) {
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
|
||||
#endif // NDEBUG
|
||||
return false;
|
||||
}
|
||||
|
||||
if (backend_src != backend_dst) {
|
||||
if (copy_from_host) {
|
||||
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream()));
|
||||
} else if (backend_src != backend_dst) {
|
||||
// copy on src stream
|
||||
if (cuda_ctx_src->device == cuda_ctx_dst->device) {
|
||||
CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
|
||||
@@ -5334,7 +5350,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return true;
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
break;
|
||||
|
||||
@@ -2,6 +2,28 @@
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
// Fills ptrs_a / ptrs_b / ptrs_c with one base pointer each per (i2, i3) pair.
// Each thread handles a single pair: i2 comes from the x launch dimension, i3 from the y
// launch dimension, and the entry is stored at flat index i3*ne2 + i2.
// The s* strides are added directly to float pointers, i.e. they count float elements.
// For src0 the indices are first divided by dps2 / dps3, so consecutive i2 / i3 values
// can map to the same src0 offset; src1 and dst use i2 / i3 unchanged.
static __global__ void k_compute_out_prod_ptrs(
|
||||
const float * src0_d, const float * src1_d, float * dst_d,
|
||||
const float ** ptrs_a, const float ** ptrs_b, float ** ptrs_c,
|
||||
const int64_t ne2, const int64_t ne3,
|
||||
const int64_t dps2, const int64_t dps3,
|
||||
const size_t s02, const size_t s03,
|
||||
const size_t s12, const size_t s13,
|
||||
const size_t s2, const size_t s3) {
|
||||
const int64_t i2 = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int64_t i3 = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
|
||||
// threads that fall outside the ne2 x ne3 index range have nothing to write
if (i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
// flat slot in the pointer arrays, with i2 varying fastest
const int64_t idx = i3*ne2 + i2;
|
||||
|
||||
ptrs_a[idx] = src0_d + (i3/dps3)*s03 + (i2/dps2)*s02;
|
||||
ptrs_b[idx] = src1_d + i3 *s13 + i2 *s12;
|
||||
ptrs_c[idx] = dst_d + i3 *s3 + i2 *s2;
|
||||
}
|
||||
|
||||
void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
@@ -67,18 +89,39 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
&beta, dst_d + i3 *s3, ldc, s2,
|
||||
batch_count));
|
||||
}
|
||||
} else if (ne2 > 1 || ne3 > 1) {
|
||||
// dps2 > 1 (src0 broadcast along dim 2 with non-uniform stride) or multiple GEMMs
|
||||
// along dim 3: compute per-GEMM pointers on the device and use a single batched GEMM.
|
||||
GGML_ASSERT(ne3 > 0);
|
||||
GGML_ASSERT(ne2 <= (int64_t) std::numeric_limits<int>::max() / ne3);
|
||||
const int batch_count = (int) (ne2 * ne3);
|
||||
|
||||
ggml_cuda_pool_alloc<const float *> ptrs_a(ctx.pool(), batch_count);
|
||||
ggml_cuda_pool_alloc<const float *> ptrs_b(ctx.pool(), batch_count);
|
||||
ggml_cuda_pool_alloc< float *> ptrs_c(ctx.pool(), batch_count);
|
||||
|
||||
const dim3 block_dims(16, 16);
|
||||
const dim3 grid_dims((ne2 + block_dims.x - 1)/block_dims.x, (ne3 + block_dims.y - 1)/block_dims.y);
|
||||
k_compute_out_prod_ptrs<<<grid_dims, block_dims, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ptrs_a.get(), ptrs_b.get(), ptrs_c.get(),
|
||||
ne2, ne3, dps2, dps3, s02, s03, s12, s13, s2, s3);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemmBatched(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, ptrs_a.get(), lda,
|
||||
ptrs_b.get(), ldb,
|
||||
&beta, ptrs_c.get(), ldc,
|
||||
batch_count));
|
||||
} else {
|
||||
// Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2
|
||||
// with non-uniform stride; would need cublasSgemmBatched with pointer arrays).
|
||||
for (int64_t i3 = 0; i3 < ne3; ++i3) {
|
||||
for (int64_t i2 = 0; i2 < ne2; ++i2) {
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
|
||||
src1_d + i3 *s13 + i2 *s12, ldb,
|
||||
&beta, dst_d + i3 *s3 + i2 *s2, ldc));
|
||||
}
|
||||
}
|
||||
// ne2 == 1 && ne3 == 1: single GEMM
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
|
||||
ne0, ne1, ne01,
|
||||
&alpha, src0_d, lda,
|
||||
src1_d, ldb,
|
||||
&beta, dst_d, ldc));
|
||||
}
|
||||
}
|
||||
|
||||
Vendored
+1
@@ -48,6 +48,7 @@
|
||||
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||
#define cublasSetStream hipblasSetStream
|
||||
#define cublasSgemm hipblasSgemm
|
||||
#define cublasSgemmBatched hipblasSgemmBatched
|
||||
#define cublasSgemmStridedBatched hipblasSgemmStridedBatched
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define cublasOperation_t hipblasOperation_t
|
||||
|
||||
Vendored
+1
@@ -32,6 +32,7 @@
|
||||
#define cublasSetMathMode mublasSetMathMode
|
||||
#define cublasSetStream mublasSetStream
|
||||
#define cublasSgemm mublasSgemm
|
||||
#define cublasSgemmBatched mublasSgemmBatched
|
||||
#define cublasSgemmStridedBatched mublasSgemmStridedBatched
|
||||
#define cublasStatus_t mublasStatus_t
|
||||
#define cublasOperation_t mublasOperation_t
|
||||
|
||||
@@ -25,7 +25,6 @@ include(ExternalProject)
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
|
||||
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
|
||||
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
|
||||
|
||||
add_library(htp_iface OBJECT
|
||||
${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c)
|
||||
@@ -72,15 +71,12 @@ function(build_htp_skel V)
|
||||
-DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
|
||||
-DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
|
||||
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
|
||||
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
|
||||
-DDSP_VERSION=${V}
|
||||
-DPREBUILT_LIB_DIR="toolv19_${V}")
|
||||
list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
|
||||
set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
build_htp_skel(v68)
|
||||
build_htp_skel(v69)
|
||||
build_htp_skel(v73)
|
||||
build_htp_skel(v75)
|
||||
build_htp_skel(v79)
|
||||
|
||||
+1359
-1274
File diff suppressed because it is too large
Load Diff
@@ -5,10 +5,12 @@
|
||||
#include "ggml-backend-impl.h"
|
||||
#include "ggml-common.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <stdio.h>
|
||||
#include "htp-ops.h"
|
||||
#include "htp/matmul-ops.h"
|
||||
|
||||
struct htp_opnode {
|
||||
ggml_tensor * node = nullptr;
|
||||
@@ -17,6 +19,13 @@ struct htp_opnode {
|
||||
|
||||
htp_op_code opcode = HTP_OP_INVALID;
|
||||
|
||||
std::vector<ggml_tensor *> extra_dsts;
|
||||
|
||||
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS] = {0};
|
||||
|
||||
htp_opnode(ggml_tensor * node = nullptr, std::vector<ggml_tensor *> fused = {}, htp_op_code opcode = HTP_OP_INVALID, std::vector<ggml_tensor *> extra_dsts = {})
|
||||
: node(node), fused(std::move(fused)), opcode(opcode), extra_dsts(std::move(extra_dsts)) {}
|
||||
|
||||
ggml_op op() const {
|
||||
return node->op;
|
||||
}
|
||||
@@ -25,6 +34,26 @@ struct htp_opnode {
|
||||
return fused.empty() ? node : fused.back();
|
||||
}
|
||||
|
||||
// Appends t to the fused list. When extra_dst is true, t is also appended to
// extra_dsts, which get_outputs() reports in addition to the primary node.
void add_fused(ggml_tensor * t, bool extra_dst = false) {
|
||||
fused.push_back(t);
|
||||
if (extra_dst) {
|
||||
extra_dsts.push_back(t);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the output tensors of this op node.
// Without extra destinations the result is the single tensor returned by dst().
// With extra destinations the result is `node` first, followed by every entry of
// extra_dsts in insertion order; dst() is not consulted in that case.
std::vector<const ggml_tensor *> get_outputs() const {
|
||||
std::vector<const ggml_tensor *> res;
|
||||
if (extra_dsts.empty()) {
|
||||
res.push_back(dst());
|
||||
} else {
|
||||
res.push_back(node);
|
||||
for (const auto * x : extra_dsts) {
|
||||
res.push_back(x);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
const ggml_tensor * src0() const {
|
||||
return node->src[0];
|
||||
}
|
||||
@@ -37,10 +66,6 @@ struct htp_opnode {
|
||||
return ggml_op_is_empty(node->op);
|
||||
}
|
||||
|
||||
void add_fused(ggml_tensor * t) {
|
||||
fused.push_back(t);
|
||||
}
|
||||
|
||||
bool stackable() const {
|
||||
switch (this->op()) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
@@ -131,87 +156,117 @@ struct htp_opformat {
|
||||
char types[16 * GGML_MAX_SRC];
|
||||
char buffs[64 * GGML_MAX_SRC];
|
||||
char names[64 * GGML_MAX_SRC];
|
||||
char kparams[128];
|
||||
|
||||
int format_tensor_dims(char * str, const struct ggml_tensor * t) {
|
||||
int format_tensor_dims(char * str, size_t max_size, const struct ggml_tensor * t) {
|
||||
if (!t) {
|
||||
return sprintf(str, "NONE");
|
||||
return snprintf(str, max_size, "NONE");
|
||||
}
|
||||
if (t->ne[2] == 1 && t->ne[3] == 1) {
|
||||
return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
|
||||
return snprintf(str, max_size, "%d:%d", (int) t->ne[0], (int) t->ne[1]);
|
||||
} else {
|
||||
return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
|
||||
return snprintf(str, max_size, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]);
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_dims(char * str, const htp_opnode & node) {
|
||||
void format_op_dims(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += format_tensor_dims(p, inputs[0]);
|
||||
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[0]), (size_t)(p_end - p));
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += format_tensor_dims(p, inputs[i]);
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[i]), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
char self[64];
|
||||
format_tensor_dims(self, node.dst());
|
||||
p += sprintf(p, "%s", self);
|
||||
format_tensor_dims(self, sizeof(self), node.dst());
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
int format_tensor_strides(char * str, const struct ggml_tensor * t) {
|
||||
int format_tensor_strides(char * str, size_t max_size, const struct ggml_tensor * t) {
|
||||
if (!t) {
|
||||
return sprintf(str, "NONE");
|
||||
return snprintf(str, max_size, "NONE");
|
||||
}
|
||||
const char * c = ggml_is_contiguous(t) ? "" : "!";
|
||||
|
||||
if (t->ne[2] == 1 && t->ne[3] == 1) {
|
||||
return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
|
||||
return snprintf(str, max_size, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c);
|
||||
} else {
|
||||
return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
|
||||
return snprintf(str, max_size, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c);
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_strides(char * str, const htp_opnode & node) {
|
||||
void format_op_strides(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += format_tensor_strides(p, inputs[0]);
|
||||
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[0]), (size_t)(p_end - p));
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += format_tensor_strides(p, inputs[i]);
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[i]), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
char self[64];
|
||||
format_tensor_strides(self, node.dst());
|
||||
p += sprintf(p, "%s", self);
|
||||
format_tensor_strides(self, sizeof(self), node.dst());
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_types(char * str, const htp_opnode & node) {
|
||||
void format_op_types(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE");
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", ggml_type_name(node.dst()->type));
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", ggml_type_name(node.dst()->type)), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
const char * tensor_buff_name(const struct ggml_tensor * t) {
|
||||
@@ -221,51 +276,102 @@ struct htp_opformat {
|
||||
return "NONE";
|
||||
}
|
||||
|
||||
void format_op_buffs(char * str, const htp_opnode & node) {
|
||||
void format_op_buffs(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += sprintf(p, "%s", tensor_buff_name(inputs[0]));
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", tensor_buff_name(inputs[i]));
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[0])), (size_t)(p_end - p));
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[i])), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", tensor_buff_name(node.dst()));
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(node.dst())), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
void format_op_names(char * str, const htp_opnode & node) {
|
||||
void format_op_names(char * str, size_t max_size, const htp_opnode & node) {
|
||||
char * p = str;
|
||||
char * p_end = str + max_size;
|
||||
auto inputs = node.get_inputs();
|
||||
|
||||
if (!inputs.empty()) {
|
||||
p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE");
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
p += sprintf(p, " x ");
|
||||
p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE");
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? inputs[0]->name : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
|
||||
p += sprintf(p, " -> ");
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p));
|
||||
}
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? inputs[i]->name : "NONE"), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
|
||||
p += sprintf(p, "%s", node.dst()->name);
|
||||
if (p < p_end) {
|
||||
p += std::min((size_t)snprintf(p, p_end - p, "%s", node.dst()->name), (size_t)(p_end - p));
|
||||
}
|
||||
}
|
||||
// Writes a short, human-readable summary of the node's kernel parameters into str
// (bounded by max_size through snprintf).
// For the four matmul opcodes, node.kernel_params is read as htp_mm_kernel_params and the
// output is "<path> vtcm <vtcm_size>", where <path> is a label chosen from kernel_type
// ("hmx-tiled", "hvx-tiled", "hvx-flat", or "unknown" when no group matches).
// Every other opcode produces the placeholder "----".
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
|
||||
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
|
||||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
|
||||
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
|
||||
const char * path = "unknown";
|
||||
int32_t type = kparams->kernel_type;
|
||||
if (type == HTP_MM_KERNEL_HMX_2D || type == HTP_MM_KERNEL_HMX_F16_BATCHED) {
|
||||
path = "hmx-tiled";
|
||||
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || type == HTP_MM_KERNEL_HVX_F32_F32_VTCM ||
|
||||
type == HTP_MM_KERNEL_HVX_QUANT_ROW || type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) {
|
||||
path = "hvx-tiled";
|
||||
} else if (type == HTP_MM_KERNEL_HVX_F16_F16_DDR || type == HTP_MM_KERNEL_HVX_F16_F32_DDR ||
|
||||
type == HTP_MM_KERNEL_HVX_F32_F32_DDR || type == HTP_MM_KERNEL_HVX_F32_F16_DDR ||
|
||||
type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) {
|
||||
path = "hvx-flat";
|
||||
}
|
||||
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
|
||||
} else {
|
||||
snprintf(str, max_size, "----");
|
||||
}
|
||||
}
|
||||
|
||||
void format(const htp_opnode & node) {
|
||||
format_op_dims(dims, node);
|
||||
format_op_strides(strides, node);
|
||||
format_op_types(types, node);
|
||||
format_op_buffs(buffs, node);
|
||||
format_op_names(names, node);
|
||||
format_op_dims(dims, sizeof(dims), node);
|
||||
format_op_strides(strides, sizeof(strides), node);
|
||||
format_op_types(types, sizeof(types), node);
|
||||
format_op_buffs(buffs, sizeof(buffs), node);
|
||||
format_op_names(names, sizeof(names), node);
|
||||
format_kernel_params(kparams, sizeof(kparams), node);
|
||||
}
|
||||
|
||||
htp_opformat() {}
|
||||
htp_opformat() {
|
||||
strides[0] = '\0';
|
||||
dims[0] = '\0';
|
||||
types[0] = '\0';
|
||||
buffs[0] = '\0';
|
||||
names[0] = '\0';
|
||||
kparams[0] = '\0';
|
||||
}
|
||||
htp_opformat(const htp_opnode & node) { format(node); }
|
||||
};
|
||||
|
||||
|
||||
@@ -19,43 +19,9 @@ add_library(${HTP_LIB} SHARED
|
||||
htp_iface_skel.c
|
||||
worker-pool.c
|
||||
hex-dma.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>
|
||||
FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE})
|
||||
|
||||
if (GGML_HEXAGON_FA_EXP2_HF)
|
||||
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
|
||||
endif()
|
||||
|
||||
# HMX acceleration: available on v73+ architectures
|
||||
set(HTP_HMX_VERSIONS v73 v75 v79 v81)
|
||||
list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)
|
||||
|
||||
if (_hmx_idx GREATER_EQUAL 0)
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
)
|
||||
|
||||
# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
|
||||
set_source_files_properties(
|
||||
hmx-flash-attn-ops.c
|
||||
hmx-matmul-ops.c
|
||||
hmx-queue.c
|
||||
PROPERTIES COMPILE_OPTIONS "-mhmx"
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1)
|
||||
endif()
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
target_sources(${HTP_LIB} PRIVATE
|
||||
hmx-queue.c
|
||||
flash-attn-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
matmul-ops.c
|
||||
binary-ops.c
|
||||
unary-ops.c
|
||||
@@ -63,7 +29,6 @@ target_sources(${HTP_LIB} PRIVATE
|
||||
softmax-ops.c
|
||||
act-ops.c
|
||||
rope-ops.c
|
||||
flash-attn-ops.c
|
||||
set-rows-ops.c
|
||||
get-rows-ops.c
|
||||
cpy-ops.c
|
||||
@@ -79,6 +44,17 @@ target_sources(${HTP_LIB} PRIVATE
|
||||
pad-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
|
||||
|
||||
if (GGML_HEXAGON_FA_EXP2_HF)
|
||||
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
|
||||
endif()
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
install(TARGETS ${HTP_LIB})
|
||||
|
||||
@@ -3,7 +3,7 @@ if (HEXAGON_TOOLCHAIN_INCLUDED)
|
||||
endif()
|
||||
set(HEXAGON_TOOLCHAIN_INCLUDED true)
|
||||
|
||||
#Cross Compiling for Hexagon
|
||||
# Cross Compiling for Hexagon
|
||||
set(HEXAGON TRUE)
|
||||
set(CMAKE_SYSTEM_NAME QURT)
|
||||
set(CMAKE_SYSTEM_PROCESSOR Hexagon)
|
||||
@@ -14,7 +14,6 @@ set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
|
||||
set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
|
||||
set(CUSTOM_RUNELF_PATH "")
|
||||
|
||||
#To fix backward compatibility with EAI addon.
|
||||
if (NOT HEXAGON_SDK_ROOT)
|
||||
set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT})
|
||||
endif()
|
||||
@@ -31,7 +30,6 @@ endif()
|
||||
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
|
||||
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
|
||||
|
||||
#Get the Binary extension of the Hexagon Toolchain
|
||||
if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows)
|
||||
set(HEXAGON_TOOLCHAIN_SUFFIX .exe)
|
||||
endif()
|
||||
@@ -48,12 +46,12 @@ set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES
|
||||
HEXAGON_TOOLS_ROOT
|
||||
)
|
||||
|
||||
#QURT Related includes and linker flags
|
||||
# QURT Related includes and linker flags
|
||||
set(V_ARCH ${HEXAGON_ARCH})
|
||||
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}")
|
||||
set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||
|
||||
if( ${TREE} MATCHES PAKMAN )
|
||||
if (${TREE} MATCHES PAKMAN)
|
||||
set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}")
|
||||
endif()
|
||||
message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}")
|
||||
@@ -83,11 +81,9 @@ set(QURT_START_LINK_LIBS
|
||||
)
|
||||
STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}")
|
||||
|
||||
set(QURT_END_LINK_LIBS
|
||||
${TARGET_DIR}/fini.o
|
||||
)
|
||||
set(QURT_END_LINK_LIBS ${TARGET_DIR}/fini.o)
|
||||
|
||||
#Non QURT related includes and linker flags
|
||||
# Non QURT related includes and linker flags
|
||||
|
||||
set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}")
|
||||
|
||||
@@ -99,8 +95,10 @@ if (NOT NO_WRAP_MEM_API)
|
||||
set(WRAP_MEMALIGN -Wl,--wrap=memalign)
|
||||
endif()
|
||||
|
||||
set(ARCH_FLAGS "-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -mhmx")
|
||||
|
||||
set(PIC_SHARED_LD_FLAGS
|
||||
-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH}
|
||||
${ARCH_FLAGS}
|
||||
-G0
|
||||
-fpic
|
||||
-Wl,-Bsymbolic
|
||||
@@ -120,13 +118,13 @@ STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}")
|
||||
|
||||
set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}")
|
||||
|
||||
#System include paths
|
||||
# System include paths
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs)
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef)
|
||||
include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs)
|
||||
|
||||
#LLVM toolchain setup
|
||||
#Compiler paths, options and architecture
|
||||
# LLVM toolchain setup
|
||||
# Compiler paths, options and architecture
|
||||
set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX})
|
||||
@@ -137,8 +135,8 @@ set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon)
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,")
|
||||
set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,")
|
||||
|
||||
#Compiler Options
|
||||
set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
# Compiler Options
|
||||
set(COMMON_FLAGS "${ARCH_FLAGS} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}")
|
||||
|
||||
set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g")
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g")
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
#include "htp-ctx.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp-ops.h"
|
||||
#include "hmx-ops.h"
|
||||
|
||||
int hmx_flash_attn_ext(struct htp_ops_context * octx);
|
||||
|
||||
// Must be multiple of 32
|
||||
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)
|
||||
@@ -633,7 +634,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
return HTP_STATUS_NO_SUPPORT;
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
// HMX path: head_dim multiple of 64, F16 KV, and no sinks
|
||||
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) {
|
||||
int ret = hmx_flash_attn_ext(octx);
|
||||
@@ -642,7 +642,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
}
|
||||
// VTCM too small or other failure -> fall through to HVX path
|
||||
}
|
||||
#endif
|
||||
|
||||
struct htp_fa_context factx;
|
||||
factx.octx = octx;
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
#ifndef HEX_COMMON_H
|
||||
#define HEX_COMMON_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifndef SIZE_MAX
|
||||
#define SIZE_MAX ((size_t)-1)
|
||||
#endif
|
||||
|
||||
#ifndef MAX
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
#ifndef MIN
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
// Rounds x up to the nearest power of two; 0 and 1 both map to 1.
// The accumulator is unsigned (matching the return type) so that shifting into bit 31,
// which happens for x > 2^30, is well defined instead of overflowing a signed int.
// For x > 2^31 the mathematical result (2^32) does not fit in 32 bits and wraps to 0.
static inline uint32_t hex_ceil_pow2(uint32_t x) {
    if (x <= 1) { return 1; }
    uint32_t p = 2;
    x--;  // so that exact powers of two map to themselves
    while (x >>= 1) { p <<= 1; }
    return p;
}
|
||||
|
||||
// Divides num by den, rounding the quotient up.
// den must be non-zero; num + den - 1 is computed in size_t and can wrap for values near SIZE_MAX.
static inline size_t hmx_ceil_div(size_t num, size_t den) {
|
||||
return (num + den - 1) / den;
|
||||
}
|
||||
|
||||
// Returns non-zero when addr has none of the bits of (align - 1) set.
// The mask only expresses an alignment test when align is a power of two.
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
|
||||
return ((size_t) addr & (align - 1)) == 0;
|
||||
}
|
||||
|
||||
// Rounds v up to the next multiple of align (any non-zero align, not only powers of two).
static inline size_t hex_align_up(size_t v, size_t align) {
|
||||
return hmx_ceil_div(v, align) * align;
|
||||
}
|
||||
|
||||
// Rounds v down to the previous multiple of align (any non-zero align).
static inline size_t hex_align_down(size_t v, size_t align) {
|
||||
return (v / align) * align;
|
||||
}
|
||||
|
||||
// Returns non-zero when the range [addr, addr + n) does not extend past the end of the
// chunk_size-sized chunk that contains addr.
// The offset is taken with a (chunk_size - 1) mask, so chunk_size is expected to be a power of two.
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||
uint32_t right_off = left_off + n;
|
||||
return right_off <= chunk_size;
|
||||
}
|
||||
|
||||
// Rounds n up to the next multiple of m (32-bit counterpart of the size_t align-up helper).
// m must be non-zero.
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
|
||||
return m * ((n + m - 1) / m);
|
||||
}
|
||||
|
||||
// Returns the smaller of two size_t values (typed alternative to the MIN macro).
static inline size_t hex_smin(size_t a, size_t b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
// Returns the larger of two size_t values (typed alternative to the MAX macro).
static inline size_t hex_smax(size_t a, size_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
// Exchanges the two pointers stored at *p1 and *p2. Both p1 and p2 must be non-NULL.
static inline void hex_swap_ptr(void ** p1, void ** p2) {
|
||||
void * t = *p1;
|
||||
*p1 = *p2;
|
||||
*p2 = t;
|
||||
}
|
||||
|
||||
// Overflow-checked size_t multiplication.
// Returns true when a * b would exceed SIZE_MAX; in that case *out is left untouched.
// Returns false and stores the product in *out otherwise.
static inline bool hex_mul_overflow(size_t a, size_t b, size_t *out) {
|
||||
if (a != 0 && b > SIZE_MAX / a) return true;
|
||||
*out = a * b;
|
||||
return false;
|
||||
}
|
||||
|
||||
// Overflow-checked size_t addition.
// Returns true when a + b would exceed SIZE_MAX; in that case *out is left untouched.
// Returns false and stores the sum in *out otherwise.
static inline bool hex_add_overflow(size_t a, size_t b, size_t *out) {
|
||||
if (a > SIZE_MAX - b) return true;
|
||||
*out = a + b;
|
||||
return false;
|
||||
}
|
||||
|
||||
#endif // HEX_COMMON_H
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <hexagon_types.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include "hex-utils.h"
|
||||
|
||||
#include "hex-profile.h"
|
||||
|
||||
@@ -127,13 +128,8 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src)
|
||||
return p;
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ < 73
|
||||
static const uint32_t dma_src_l2_bypass_on = 1;
|
||||
static const uint32_t dma_dst_l2_bypass_on = 0;
|
||||
#else
|
||||
static const uint32_t dma_src_l2_bypass_on = 1;
|
||||
static const uint32_t dma_dst_l2_bypass_on = 1;
|
||||
#endif
|
||||
|
||||
static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) {
|
||||
if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) {
|
||||
|
||||
@@ -11,14 +11,7 @@
|
||||
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hex-dump.h"
|
||||
|
||||
#ifndef MAX
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#endif
|
||||
|
||||
#ifndef MIN
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#endif
|
||||
#include "hex-common.h"
|
||||
|
||||
static inline uint64_t hex_get_cycles() {
|
||||
uint64_t cycles = 0;
|
||||
@@ -32,54 +25,6 @@ static inline uint64_t hex_get_pktcnt() {
|
||||
return pktcnt;
|
||||
}
|
||||
|
||||
static inline uint32_t hex_ceil_pow2(uint32_t x) {
|
||||
if (x <= 1) { return 1; }
|
||||
int p = 2;
|
||||
x--;
|
||||
while (x >>= 1) { p <<= 1; }
|
||||
return p;
|
||||
}
|
||||
|
||||
static inline size_t hmx_ceil_div(size_t num, size_t den) {
|
||||
return (num + den - 1) / den;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_aligned(const void * addr, uint32_t align) {
|
||||
return ((size_t) addr & (align - 1)) == 0;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_up(size_t v, size_t align) {
|
||||
return hmx_ceil_div(v, align) * align;
|
||||
}
|
||||
|
||||
static inline size_t hex_align_down(size_t v, size_t align) {
|
||||
return (v / align) * align;
|
||||
}
|
||||
|
||||
static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
|
||||
uint32_t left_off = (size_t) addr & (chunk_size - 1);
|
||||
uint32_t right_off = left_off + n;
|
||||
return right_off <= chunk_size;
|
||||
}
|
||||
|
||||
static inline uint32_t hex_round_up(uint32_t n, uint32_t m) {
|
||||
return m * ((n + m - 1) / m);
|
||||
}
|
||||
|
||||
static inline size_t hex_smin(size_t a, size_t b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
static inline size_t hex_smax(size_t a, size_t b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
static inline void hex_swap_ptr(void ** p1, void ** p2) {
|
||||
void * t = *p1;
|
||||
*p1 = *p2;
|
||||
*p2 = t;
|
||||
}
|
||||
|
||||
static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) {
|
||||
const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height));
|
||||
Q6_l2fetch_AP((void *) p, control);
|
||||
|
||||
@@ -49,7 +49,7 @@
|
||||
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
|
||||
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
|
||||
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
|
||||
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) {
|
||||
static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
|
||||
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
|
||||
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
|
||||
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
|
||||
@@ -70,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV,
|
||||
+ k_dma_size * 2 // K DMA x2
|
||||
+ v_dma_size * 2 // V DMA x2
|
||||
+ k_tile_size * 1 // K tiles
|
||||
+ v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
|
||||
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
|
||||
+ s_tile_size * 2 // S + P
|
||||
+ d_tile_size * 1 // D (diagonal matrix)
|
||||
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
|
||||
@@ -290,7 +290,7 @@ static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) =
|
||||
|
||||
struct hmx_fa_context {
|
||||
const struct htp_ops_context * octx;
|
||||
bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
|
||||
bool pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2
|
||||
uint32_t n_threads;
|
||||
|
||||
// Op parameters
|
||||
@@ -409,7 +409,7 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data)
|
||||
return;
|
||||
}
|
||||
|
||||
__fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
|
||||
__fp16 * v_tiles_dest = factx->pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0];
|
||||
|
||||
struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL;
|
||||
htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start);
|
||||
@@ -1312,13 +1312,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS);
|
||||
|
||||
const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc;
|
||||
const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
|
||||
const bool pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2);
|
||||
|
||||
// Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1
|
||||
const uint32_t n_threads = use_pipeline ? n_threads_init : 1;
|
||||
const uint32_t n_threads = pipeline ? n_threads_init : 1;
|
||||
|
||||
FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu",
|
||||
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget);
|
||||
neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, pipeline, vtcm_budget);
|
||||
|
||||
// ======== Build context ========
|
||||
struct hmx_fa_context factx;
|
||||
@@ -1339,7 +1339,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
factx.n_kv_blocks = n_kv_blocks;
|
||||
factx.is_q_fp32 = (q->type == HTP_TYPE_F32);
|
||||
factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32);
|
||||
factx.use_pipeline = use_pipeline;
|
||||
factx.pipeline = pipeline;
|
||||
factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1);
|
||||
|
||||
// Extract op parameters (mutable during softcap adjustment, then stored as const in factx)
|
||||
@@ -1405,7 +1405,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes);
|
||||
factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes);
|
||||
factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
|
||||
if (use_pipeline) {
|
||||
if (pipeline) {
|
||||
factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes);
|
||||
} else {
|
||||
factx.vtcm_v_tiles[1] = NULL;
|
||||
@@ -1456,7 +1456,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
// ======== HMX lock strategy ========
|
||||
// Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend.
|
||||
// Fallback: main thread holds the lock (original behavior).
|
||||
if (!factx.use_pipeline) {
|
||||
if (!factx.pipeline) {
|
||||
HAP_compute_res_hmx_lock(ctx->vtcm_rctx);
|
||||
}
|
||||
|
||||
@@ -1550,7 +1550,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
const size_t k_src_stride = size_k_row_padded / sizeof(__fp16);
|
||||
const size_t v_src_stride = size_v_row_padded / sizeof(__fp16);
|
||||
|
||||
if (factx.use_pipeline) {
|
||||
if (factx.pipeline) {
|
||||
// ==================================================================
|
||||
// Pipeline path: HVX phases ‖ HMX queue worker
|
||||
// ==================================================================
|
||||
@@ -1780,7 +1780,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br);
|
||||
|
||||
// HMX: O_final = diag(1/l) @ O_prev
|
||||
if (factx.use_pipeline) {
|
||||
if (factx.pipeline) {
|
||||
on_job.o_curr = o_tile_curr;
|
||||
on_job.o_prev = o_tile_prev;
|
||||
on_job.d_tiles = factx.vtcm_d_tiles;
|
||||
@@ -1826,7 +1826,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
|
||||
} // end KV head loop
|
||||
} // end batch loop
|
||||
|
||||
if (factx.use_pipeline) {
|
||||
if (factx.pipeline) {
|
||||
hmx_queue_suspend(ctx->hmx_queue);
|
||||
} else {
|
||||
HAP_compute_res_hmx_unlock(ctx->vtcm_rctx);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,6 +0,0 @@
|
||||
// HMX operations compiled as a single translation unit.
|
||||
// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO.
|
||||
|
||||
#include "hmx-queue.c"
|
||||
#include "hmx-matmul-ops.c"
|
||||
#include "hmx-flash-attn-ops.c"
|
||||
@@ -1,88 +0,0 @@
|
||||
// HMX operation entry-point declarations.
|
||||
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)
|
||||
|
||||
#ifndef HMX_OPS_H
|
||||
#define HMX_OPS_H
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "htp-ops.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
float *dst;
|
||||
const float *activation;
|
||||
const __fp16 *permuted_weight;
|
||||
int m;
|
||||
int k;
|
||||
int n;
|
||||
int act_stride;
|
||||
int weight_stride;
|
||||
int dst_stride;
|
||||
int ne02;
|
||||
int ne03;
|
||||
int ne12;
|
||||
int ne13;
|
||||
size_t src0_nb2;
|
||||
size_t src0_nb3;
|
||||
size_t src1_nb2;
|
||||
size_t src1_nb3;
|
||||
size_t dst_nb2;
|
||||
size_t dst_nb3;
|
||||
} hmx_matmul_f16_f32_batched_params_t;
|
||||
|
||||
// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
|
||||
// act_stride: activation row stride in elements (= k for contiguous, or
|
||||
// nb[1]/sizeof(float) for permuted tensors like attention Q).
|
||||
// weight_stride: weight row stride in elements (= k for compact weights, or
|
||||
// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
|
||||
int hmx_matmul_f16_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const __fp16 *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int act_stride,
|
||||
int weight_stride);
|
||||
|
||||
// Batched F16 wrapper over hmx_mat_mul_f16_f32.
|
||||
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
|
||||
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params);
|
||||
|
||||
// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4)
|
||||
int hmx_matmul_2d_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const uint8_t *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int act_stride,
|
||||
int weight_stride,
|
||||
int weight_type);
|
||||
|
||||
struct mmid_row_mapping;
|
||||
|
||||
int hmx_matmul_id_2d_f32(struct htp_context *ctx,
|
||||
float *restrict dst,
|
||||
const float *activation,
|
||||
const uint8_t *permuted_weight,
|
||||
int m, int k, int n,
|
||||
int ne11,
|
||||
size_t act_nb1, size_t act_nb2,
|
||||
size_t dst_nb1, size_t dst_nb2,
|
||||
int weight_stride,
|
||||
int weight_type,
|
||||
const struct mmid_row_mapping *matrix_rows,
|
||||
int cur_a,
|
||||
int mapping_stride);
|
||||
|
||||
// HMX flash attention
|
||||
int hmx_flash_attn_ext(struct htp_ops_context * octx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // HMX_OPS_H
|
||||
@@ -13,7 +13,9 @@
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifndef HTP_MAX_NTHREADS
|
||||
#define HTP_MAX_NTHREADS 10
|
||||
#endif
|
||||
#define HTP_MAX_MMAPS 16
|
||||
|
||||
// Memory mapping
|
||||
@@ -42,9 +44,13 @@ struct htp_ops_context {
|
||||
|
||||
enum htp_op_code op; // FIXME: rename to opcode
|
||||
int32_t op_params[HTP_OP_MAX_PARAMS];
|
||||
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS];
|
||||
|
||||
const struct htp_tensor * src[HTP_OP_MAX_INPUTS];
|
||||
const struct htp_tensor * dst;
|
||||
union {
|
||||
const struct htp_tensor * dst;
|
||||
const struct htp_tensor * dsts[HTP_OP_MAX_OUTPUTS];
|
||||
};
|
||||
|
||||
// TODO convert these to an array
|
||||
struct htp_spad src0_spad;
|
||||
@@ -87,13 +93,13 @@ struct htp_context {
|
||||
|
||||
struct htp_ops_context octx;
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap
|
||||
#endif
|
||||
};
|
||||
|
||||
int op_matmul(struct htp_ops_context * octx);
|
||||
int op_matmul_id(struct htp_ops_context * octx);
|
||||
int op_matmul_qkv(struct htp_ops_context * octx);
|
||||
int op_matmul_ffn(struct htp_ops_context * octx);
|
||||
int op_binary(struct htp_ops_context * octx);
|
||||
int op_unary(struct htp_ops_context * octx);
|
||||
int op_sum_rows(struct htp_ops_context * octx);
|
||||
|
||||
@@ -28,18 +28,19 @@ enum htp_data_type {
|
||||
HTP_TYPE_MXFP4 = 39,
|
||||
|
||||
// types used internally for repack, dyn.quant, etc
|
||||
HTP_TYPE_Q4_0x4x2 = 200,
|
||||
HTP_TYPE_Q4_1x4x2,
|
||||
HTP_TYPE_Q8_0x4x2,
|
||||
HTP_TYPE_MXFP4x4x2,
|
||||
HTP_TYPE_Q4_0_TILED = 200,
|
||||
HTP_TYPE_Q4_1_TILED,
|
||||
HTP_TYPE_Q8_0_TILED,
|
||||
HTP_TYPE_MXFP4_TILED,
|
||||
|
||||
HTP_TYPE_INVALID
|
||||
};
|
||||
|
||||
// Constants for internal types
|
||||
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
|
||||
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks
|
||||
#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks
|
||||
#define QK_Q4_0_TILED 256 // 32x32 Q4_0 tiled layout
|
||||
#define QK_Q8_0_TILED 128 // 32x32 Q8_0 tiled layout
|
||||
#define QK_MXFP4_TILED 256 // 32x32 MXFP4 tiled layout
|
||||
|
||||
|
||||
|
||||
// Mask to enable various stages of the Ops.
|
||||
@@ -57,6 +58,8 @@ enum htp_op_code {
|
||||
HTP_OP_DIV = 3,
|
||||
HTP_OP_MUL_MAT,
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_MUL_MAT_QKV,
|
||||
HTP_OP_MUL_MAT_FFN,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_RMS_NORM_MUL,
|
||||
HTP_OP_UNARY_SILU,
|
||||
@@ -99,7 +102,9 @@ enum htp_op_code {
|
||||
|
||||
#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS
|
||||
#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS
|
||||
#define HTP_OP_MAX_OUTPUTS 4
|
||||
#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS
|
||||
#define HTP_OP_MAX_KERN_PARAMS 32
|
||||
|
||||
#define HTP_OP_MAX_BUFS 16
|
||||
#define HTP_OP_MAX_REQS 256
|
||||
@@ -142,8 +147,10 @@ struct htp_op_desc {
|
||||
uint32_t opcode; // GGML/HTP Op
|
||||
uint32_t flags; // Op flags
|
||||
int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm
|
||||
int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; // generic blob for host-precomputed parameters
|
||||
uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices
|
||||
uint16_t dst; // Output tensor index
|
||||
uint16_t dst[HTP_OP_MAX_OUTPUTS]; // Output tensor indices
|
||||
uint16_t pad[2]; // padding to align to 64 bits
|
||||
};
|
||||
|
||||
#ifndef HTP_MAX_NTHREADS
|
||||
|
||||
@@ -11,12 +11,13 @@ struct htp_iface_pmu_conf {
|
||||
};
|
||||
|
||||
interface htp_iface : remote_handle64 {
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem);
|
||||
AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 n_hmx, in uint64 max_vmem);
|
||||
AEEResult stop();
|
||||
AEEResult mmap(in uint32 fd, in uint32 size);
|
||||
AEEResult munmap(in uint32 fd);
|
||||
AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu);
|
||||
AEEResult etm(in uint32 enable);
|
||||
AEEResult hwinfo(rout uint32 n_threads, rout uint32 n_hvx, rout uint32 n_hmx, rout uint64 vtcm_size);
|
||||
};
|
||||
|
||||
#endif /* HTP_IDL */
|
||||
|
||||
@@ -170,25 +170,7 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
|
||||
}
|
||||
#endif
|
||||
|
||||
/* Q6_Vsf_equals_Vw is only available on v73+.*/
|
||||
#if __HVX_ARCH__ < 73
|
||||
static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in)
|
||||
{
|
||||
HVX_Vector const vzero = Q6_V_vzero();
|
||||
HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero);
|
||||
HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in);
|
||||
HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift);
|
||||
HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift);
|
||||
HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized);
|
||||
HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp));
|
||||
return ret;
|
||||
}
|
||||
|
||||
static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
|
||||
{
|
||||
return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in));
|
||||
}
|
||||
#endif
|
||||
|
||||
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
|
||||
// This looks complicated.
|
||||
@@ -305,4 +287,17 @@ static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) {
|
||||
|
||||
#endif // __HVX_ARCH__ < 79
|
||||
|
||||
// Returns the activation vector for tile index kt, loading from memory only once per group of four.
// When kt % 4 == 0 a vector is loaded from y_q + kt * 32, cached in *v_act_all and returned.
// For kt % 4 == 1, 2, 3 the cached vector is returned rotated by 32, 64 or 96 via Q6_V_vror_VR.
// NOTE(review): the non-zero cases read *v_act_all without loading, so this appears to rely on the
// kt % 4 == 0 call of the same group having run first — confirm against callers.
static inline HVX_Vector hvx_vec_load_act_tile(const uint8_t * y_q, uint32_t kt, HVX_Vector * v_act_all) {
|
||||
if (kt % 4 == 0) {
|
||||
*v_act_all = hvx_vmem(y_q + kt * 32);
|
||||
return *v_act_all;
|
||||
} else if (kt % 4 == 1) {
|
||||
return Q6_V_vror_VR(*v_act_all, 32);
|
||||
} else if (kt % 4 == 2) {
|
||||
return Q6_V_vror_VR(*v_act_all, 64);
|
||||
} else {
|
||||
return Q6_V_vror_VR(*v_act_all, 96);
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* HVX_BASE_H */
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -361,7 +361,7 @@ static void vtcm_free(struct htp_context * ctx) {
|
||||
static void htp_packet_callback(dspqueue_t queue, int error, void * context);
|
||||
static void htp_error_callback(dspqueue_t queue, int error, void * context);
|
||||
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) {
|
||||
AEEResult htp_iface_start(remote_handle64 handle, uint32_t sess_id, uint64_t dsp_queue_id, uint32_t n_hvx, uint32_t n_hmx, uint64_t max_vmem) {
|
||||
struct htp_context * ctx = (struct htp_context *) handle;
|
||||
|
||||
if (!ctx) {
|
||||
@@ -395,10 +395,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
return AEE_ENOMEMORY;
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
ctx->hmx_enabled = use_hmx;
|
||||
ctx->hmx_enabled = n_hmx;
|
||||
ctx->hmx_queue = NULL;
|
||||
if (use_hmx) {
|
||||
if (n_hmx) {
|
||||
ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx);
|
||||
if (ctx->hmx_queue) {
|
||||
ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS];
|
||||
@@ -407,8 +406,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que
|
||||
ctx->hmx_enabled = false;
|
||||
}
|
||||
}
|
||||
FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx);
|
||||
#endif
|
||||
FARF(HIGH, "HMX %s (n_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", n_hmx);
|
||||
|
||||
qurt_sysenv_max_hthreads_t hw_threads;
|
||||
qurt_sysenv_get_max_hw_threads(&hw_threads);
|
||||
@@ -481,13 +479,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
|
||||
dma_queue_delete(ctx->dma[i]);
|
||||
}
|
||||
|
||||
#ifdef HTP_HAS_HMX
|
||||
if (ctx->hmx_queue) {
|
||||
hmx_queue_delete(ctx->hmx_queue);
|
||||
ctx->hmx_queue = NULL;
|
||||
}
|
||||
ctx->hmx_enabled = false;
|
||||
#endif
|
||||
|
||||
vtcm_free(ctx);
|
||||
|
||||
@@ -500,6 +496,36 @@ AEEResult htp_iface_stop(remote_handle64 handle) {
|
||||
return AEE_SUCCESS;
|
||||
}
|
||||
|
||||
// Reports the hardware resources visible to this session.
//
//   n_threads  [out] worker thread count (currently forced equal to *n_hvx)
//   n_hvx      [out] HVX unit count: bits 8..15 of qurt_hvx_get_units(),
//                    clamped to the number of hardware threads and to
//                    HTP_MAX_NTHREADS
//   n_hmx      [out] always set to 1
//   vtcm_size  [out] total VTCM size in bytes as reported by
//                    HAP_compute_res_query_VTCM(), or 8MB if the query fails
//
// Returns AEE_EBADPARM if any output pointer is NULL, AEE_SUCCESS otherwise.
// The handle is not used.
AEEResult htp_iface_hwinfo(remote_handle64 handle, uint32_t * n_threads, uint32_t * n_hvx, uint32_t * n_hmx, uint64_t * vtcm_size) {
    (void) handle;

    if (!n_threads || !n_hvx || !n_hmx || !vtcm_size) {
        return AEE_EBADPARM;
    }

    qurt_sysenv_max_hthreads_t hw_threads;
    qurt_sysenv_get_max_hw_threads(&hw_threads);

    uint32_t n_hvx_val = (qurt_hvx_get_units() >> 8) & 0xFF;
    if (n_hvx_val > hw_threads.max_hthreads) {
        n_hvx_val = hw_threads.max_hthreads;
    }
    if (n_hvx_val > HTP_MAX_NTHREADS) {
        n_hvx_val = HTP_MAX_NTHREADS;
    }

    // for now we force n_threads == n_hvx
    *n_threads = n_hvx_val;
    *n_hvx     = n_hvx_val;
    *n_hmx     = 1;

    // Query into a correctly typed local (the API takes unsigned int *) and
    // only trust the value when the call succeeds and reports a non-zero size;
    // otherwise report the documented 8MB fallback instead of whatever the
    // failed call may have left behind.
    const unsigned int vtcm_fallback = 8 * 1024 * 1024;
    unsigned int       vtcm_sz       = vtcm_fallback;
    if (HAP_compute_res_query_VTCM(0, &vtcm_sz, NULL, NULL, NULL) != 0 || vtcm_sz == 0) {
        vtcm_sz = vtcm_fallback;
    }
    *vtcm_size = vtcm_sz;

    return AEE_SUCCESS;
}
|
||||
|
||||
static void htp_error_callback(dspqueue_t queue, int error, void * context) {
|
||||
// No errors expected on the DSP.
|
||||
FARF(ERROR, "Error callback: 0x%08x", (unsigned) error);
|
||||
@@ -554,6 +580,12 @@ static int execute_op(struct htp_ops_context * octx) {
|
||||
case HTP_OP_MUL_MAT_ID:
|
||||
return op_matmul_id(octx);
|
||||
|
||||
case HTP_OP_MUL_MAT_QKV:
|
||||
return op_matmul_qkv(octx);
|
||||
|
||||
case HTP_OP_MUL_MAT_FFN:
|
||||
return op_matmul_ffn(octx);
|
||||
|
||||
case HTP_OP_MUL:
|
||||
case HTP_OP_ADD:
|
||||
case HTP_OP_SUB:
|
||||
@@ -762,8 +794,9 @@ static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, str
|
||||
}
|
||||
}
|
||||
|
||||
static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
|
||||
static int proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) {
|
||||
memcpy(octx->op_params, op->params, sizeof(octx->op_params));
|
||||
memcpy(octx->kernel_params, op->kernel_params, sizeof(octx->kernel_params));
|
||||
octx->flags = op->flags;
|
||||
octx->op = op->opcode;
|
||||
|
||||
@@ -785,22 +818,41 @@ static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens,
|
||||
src->ne[0], src->ne[1], src->ne[3], src->ne[3]);
|
||||
}
|
||||
|
||||
// Prep output tensor
|
||||
struct htp_tensor *dst = tens + op->dst;
|
||||
// Prep output tensors
|
||||
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
|
||||
uint16_t dst_idx = op->dst[i];
|
||||
if (dst_idx == 0xffff) {
|
||||
octx->dsts[i] = NULL;
|
||||
continue;
|
||||
}
|
||||
struct htp_tensor *dst = tens + dst_idx;
|
||||
octx->dsts[i] = dst;
|
||||
|
||||
octx->dst = dst;
|
||||
FARF(HIGH, "prep-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, dst_idx, (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
|
||||
}
|
||||
|
||||
FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
|
||||
int status = execute_op(octx);
|
||||
|
||||
(void) execute_op(octx);
|
||||
octx->src0_spad.src = NULL;
|
||||
octx->src1_spad.src = NULL;
|
||||
octx->src2_spad.src = NULL;
|
||||
octx->src3_spad.src = NULL;
|
||||
octx->dst_spad.src = NULL;
|
||||
|
||||
// flush buffers on output
|
||||
hex_l2flush((void *) dst->data, dst->size);
|
||||
dst->flags |= HTP_TENSOR_FLUSHED;
|
||||
for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) {
|
||||
if (octx->dsts[i]) {
|
||||
struct htp_tensor *dst = (struct htp_tensor *)octx->dsts[i];
|
||||
hex_l2flush((void *) dst->data, dst->size);
|
||||
dst->flags |= HTP_TENSOR_FLUSHED;
|
||||
|
||||
FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]);
|
||||
FARF(HIGH, "post-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, op->dst[i], (void*) dst->data, dst->size,
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]);
|
||||
}
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
#define DSPQUEUE_POLL_TIMEOUT_USEC 100
|
||||
@@ -892,20 +944,26 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
}
|
||||
}
|
||||
|
||||
int op_status = HTP_STATUS_OK;
|
||||
uint32_t op_wakeup = n_ops / 2; // half-way through the batch
|
||||
|
||||
for (uint32_t i=0; i < n_ops; i++) {
|
||||
struct profile_data prof;
|
||||
|
||||
if (i == (n_ops-1)) {
|
||||
// wake up the host before starting the last op
|
||||
if (i == op_wakeup) {
|
||||
dspqueue_write_early_wakeup_noblock(queue, 0, 0);
|
||||
}
|
||||
|
||||
profile_start(ctx->profiler, &prof);
|
||||
|
||||
proc_op_req(octx, tens, i, &ops[i]);
|
||||
op_status = proc_op_req(octx, tens, i, &ops[i]);
|
||||
|
||||
profile_stop(ctx->profiler, &prof);
|
||||
|
||||
if (op_status != HTP_STATUS_OK) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (ctx->profiler) {
|
||||
pds[i].opcode = ops[i].opcode;
|
||||
pds[i].usecs = prof.usecs;
|
||||
@@ -919,7 +977,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
|
||||
|
||||
struct htp_opbatch_rsp rsp;
|
||||
rsp.id = req.id;
|
||||
rsp.status = HTP_STATUS_OK;
|
||||
rsp.status = op_status;
|
||||
rsp.n_bufs = n_bufs;
|
||||
rsp.n_tensors = n_tens;
|
||||
rsp.n_ops = n_ops;
|
||||
|
||||
+2729
-4117
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,508 @@
|
||||
#ifndef HTP_MATMUL_OPS_H
|
||||
#define HTP_MATMUL_OPS_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include "htp-ops.h"
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hex-common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// --- HMX Tile Constraints ---
|
||||
#define HTP_MM_HMX_TILE_N_COLS 32
|
||||
#define HTP_MM_HMX_TILE_N_ROWS 32
|
||||
#define HTP_MM_HMX_TILE_SIZE (32 * 32 * sizeof(__fp16)) // 2048 bytes
|
||||
#define HTP_MM_HMX_TILE_N_ELMS 1024
|
||||
#define HTP_MM_HMX_MIN_NROWS 4
|
||||
|
||||
// --- Weight Repacked Tile Sizes ---
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_0 576
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_Q4_1 640
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_Q8_0 1088
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_IQ4_NL 576
|
||||
#define HTP_MM_WEIGHT_TILE_SIZE_MXFP4 544
|
||||
|
||||
// --- Weight Repacked Aligned Tile Sizes ---
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0 640
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1 640
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0 1152
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_IQ4_NL 640
|
||||
#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4 640
|
||||
|
||||
// --- Activation Tiled Block Sizes (including padding) ---
|
||||
#define HTP_MM_ACT_TILE_SIZE_Q8_0 1152
|
||||
#define HTP_MM_ACT_TILE_SIZE_Q8_1 1280
|
||||
|
||||
#define HTP_MM_MAX_PREFETCH 16
|
||||
|
||||
// --- Solver Cost Model Penalty Weights (HMX-specific) ---
|
||||
#define HTP_MM_HMX_COST_W_DEQUANT 3 // cost penalty for quantized weight loading/dequantization
|
||||
#define HTP_MM_HMX_COST_A_CONVERT 2 // cost penalty for activation loading/conversion
|
||||
|
||||
// --- DMA Activation Transfer Configuration ---
|
||||
#define HTP_MM_DMA_ACT_ROWS_PER_STEP 2
|
||||
#define HTP_MM_DMA_ACT_MULTIPLIER 4
|
||||
|
||||
// Matmul kernel selector. Values are spelled out explicitly; they match the
// implicit numbering of the declaration order.
enum htp_mm_kernel_type {
    HTP_MM_KERNEL_UNSUPPORTED        = 0,

    // HMX paths
    HTP_MM_KERNEL_HMX_2D             = 1,
    HTP_MM_KERNEL_HMX_F16_BATCHED    = 2,

    // HVX floating-point paths
    HTP_MM_KERNEL_HVX_F16_F16_VTCM   = 3,
    HTP_MM_KERNEL_HVX_F16_F16_DDR    = 4,
    HTP_MM_KERNEL_HVX_F16_F32_DDR    = 5,

    HTP_MM_KERNEL_HVX_F32_F32_VTCM   = 6,
    HTP_MM_KERNEL_HVX_F32_F32_DDR    = 7,
    HTP_MM_KERNEL_HVX_F32_F16_DDR    = 8,

    // HVX quantized paths
    HTP_MM_KERNEL_HVX_QUANT_ROW      = 9,   // standard row-wise parallel quantization
    HTP_MM_KERNEL_HVX_QUANT_BLOCK    = 10,  // parallel block-wise quantization
    HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT = 11,  // row-wise fallback flat quantization
};
|
||||
|
||||
// Op-specific struct for precomputed matmul params
|
||||
struct htp_mm_kernel_params {
|
||||
int32_t kernel_type; // enum htp_mm_kernel_type
|
||||
int32_t pipeline; // 1 = pipelined execution, 0 = standard
|
||||
int32_t m_chunk; // Row chunk size (M chunk)
|
||||
int32_t n_chunk; // Col chunk size (N chunk)
|
||||
int32_t n_threads; // Number of threads to spawn
|
||||
int32_t n_act_threads; // Number of threads for activation preparation
|
||||
int32_t n_hmx; // 1 = use HMX, 0 = use HVX
|
||||
int32_t n_prefetch; // Prefetch lookahead buffers/rows in VTCM
|
||||
int32_t tile_size; // Weight tile size
|
||||
int32_t aligned_tile_size; // Aligned weight tile size (padded to 128)
|
||||
int32_t src1_row_size; // Row size for quantized activation
|
||||
int32_t vtcm_size; // Total required scratchpad size in VTCM
|
||||
int32_t vtcm_src0_size; // src0 scratchpad size in VTCM
|
||||
int32_t vtcm_src1_size; // src1 scratchpad size in VTCM
|
||||
int32_t vtcm_src2_size; // src2 scratchpad size in VTCM (fused only)
|
||||
int32_t vtcm_src3_size; // src3 scratchpad size in VTCM (fused only)
|
||||
int32_t vtcm_dst_size; // dst scratchpad size in VTCM
|
||||
|
||||
// Precomputed division values
|
||||
struct fastdiv_values div_ne12_ne1;
|
||||
struct fastdiv_values div_ne1;
|
||||
struct fastdiv_values div_r2;
|
||||
struct fastdiv_values div_r3;
|
||||
struct fastdiv_values div_ne11;
|
||||
};
|
||||
|
||||
#if defined(__cplusplus)
|
||||
static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
|
||||
#else
|
||||
_Static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob");
|
||||
#endif
|
||||
|
||||
// Pair of 32-bit indices (i1, i2).
// NOTE(review): presumably identifies one row handled by the MUL_MAT_ID path —
// confirm against the code that fills and reads this struct.
struct mmid_row_mapping {
|
||||
uint32_t i1;
|
||||
uint32_t i2;
|
||||
};
|
||||
|
||||
// Search for optimal (mc, nc) chunk sizes within VTCM budget.
|
||||
static inline int htp_mm_hmx_compute_chunks(size_t vtcm_total,
|
||||
size_t overhead,
|
||||
size_t per_n_cost,
|
||||
size_t per_m_cost,
|
||||
size_t per_mn_cost,
|
||||
size_t m,
|
||||
size_t n,
|
||||
size_t m_block_cost,
|
||||
size_t n_block_cost,
|
||||
size_t * m_chunk_out,
|
||||
size_t * n_chunk_out,
|
||||
size_t * total_out) {
|
||||
if (m == 0 || n == 0) return -1;
|
||||
if (vtcm_total <= overhead) return -1;
|
||||
if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1;
|
||||
|
||||
const size_t usable = vtcm_total - overhead;
|
||||
|
||||
size_t best_cost = SIZE_MAX;
|
||||
size_t best_mn = 0;
|
||||
size_t best_m = 0, best_n = 0;
|
||||
|
||||
const size_t n_max = hex_align_down((size_t)n, HTP_MM_HMX_TILE_N_COLS);
|
||||
for (size_t nc = n_max; nc >= HTP_MM_HMX_TILE_N_COLS; nc -= HTP_MM_HMX_TILE_N_COLS) {
|
||||
size_t n_fixed = 0, ncmn = 0, mc_denom = 0;
|
||||
if (hex_mul_overflow(nc, per_n_cost, &n_fixed)) continue;
|
||||
if (n_fixed >= usable) goto next_nc;
|
||||
|
||||
if (hex_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc;
|
||||
if (hex_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc;
|
||||
|
||||
{
|
||||
size_t remain = usable - n_fixed;
|
||||
size_t mc = remain / mc_denom;
|
||||
mc = hex_align_down(mc, HTP_MM_HMX_TILE_N_ROWS);
|
||||
mc = hex_smin(mc, m);
|
||||
|
||||
if (mc == 0) {
|
||||
goto next_nc;
|
||||
}
|
||||
|
||||
size_t mblocks = ((size_t) m + mc - 1) / mc;
|
||||
size_t nblocks = ((size_t) n + nc - 1) / nc;
|
||||
size_t cost = mblocks * m_block_cost + nblocks * n_block_cost;
|
||||
size_t mn = mc * nc;
|
||||
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
|
||||
best_cost = cost;
|
||||
best_mn = mn;
|
||||
best_m = mc;
|
||||
best_n = nc;
|
||||
}
|
||||
}
|
||||
|
||||
next_nc:
|
||||
if (nc == HTP_MM_HMX_TILE_N_COLS) break; // avoid size_t underflow
|
||||
}
|
||||
|
||||
if (best_m == 0 || best_n == 0) return -1;
|
||||
|
||||
// Compute exact total (with overflow checks)
|
||||
size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0;
|
||||
if (hex_mul_overflow(best_n, per_n_cost, &t0)) return -1;
|
||||
if (hex_mul_overflow(best_m, per_m_cost, &t1)) return -1;
|
||||
if (hex_mul_overflow(best_m, best_n, &mn)) return -1;
|
||||
if (hex_mul_overflow(mn, per_mn_cost, &t2)) return -1;
|
||||
if (hex_add_overflow(t0, t1, &total)) return -1;
|
||||
if (hex_add_overflow(total, t2, &total)) return -1;
|
||||
if (hex_add_overflow(total, overhead, &total)) return -1;
|
||||
|
||||
*m_chunk_out = best_m;
|
||||
*n_chunk_out = best_n;
|
||||
*total_out = total;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// --- Tile Size Helpers ---
|
||||
static inline uint32_t htp_mm_get_weight_tile_size(int weight_type) {
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_Q4_0;
|
||||
case HTP_TYPE_Q4_1:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_Q4_1;
|
||||
case HTP_TYPE_Q8_0:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_Q8_0;
|
||||
case HTP_TYPE_MXFP4:
|
||||
return HTP_MM_WEIGHT_TILE_SIZE_MXFP4;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
static inline uint32_t htp_mm_get_weight_aligned_tile_size(int weight_type) {
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0;
|
||||
case HTP_TYPE_Q4_1:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1;
|
||||
case HTP_TYPE_Q8_0:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0;
|
||||
case HTP_TYPE_MXFP4:
|
||||
return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4;
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Activation/Row Size Helpers ---
|
||||
static inline size_t htp_mm_q8_0_tiled_row_size(uint32_t ne) {
|
||||
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
|
||||
const uint32_t nb_32 = ne_padded / 32;
|
||||
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_0;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_q8_1_tiled_row_size(uint32_t ne) {
|
||||
const uint32_t ne_padded = ((ne + 127) / 128) * 128;
|
||||
const uint32_t nb_32 = ne_padded / 32;
|
||||
return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_1;
|
||||
}
|
||||
|
||||
// Flat Q8_0 activation row size: a quant area of ne bytes and a scale area of
// 2 bytes per 32-element block, each rounded up to 128 bytes.
static inline size_t htp_mm_q8_0_flat_row_size(uint32_t ne) {
    const uint32_t n_blocks    = (ne + 31) / 32;
    const uint32_t scale_bytes = hex_align_up(n_blocks * 2, 128);
    const uint32_t quant_bytes = hex_align_up(ne, 128);
    return quant_bytes + scale_bytes;
}
|
||||
|
||||
// Flat Q8_1 activation row size: a quant area of ne bytes and a scale area of
// 4 bytes per 32-element block, each rounded up to 128 bytes.
static inline size_t htp_mm_q8_1_flat_row_size(uint32_t ne) {
    const uint32_t n_blocks    = (ne + 31) / 32;
    const uint32_t scale_bytes = hex_align_up(n_blocks * 4, 128);
    const uint32_t quant_bytes = hex_align_up(ne, 128);
    return quant_bytes + scale_bytes;
}
|
||||
|
||||
static inline size_t htp_mm_get_tiled_row_stride(int weight_type, uint32_t k) {
|
||||
uint32_t nb = (k + QK_Q4_0_TILED - 1) / QK_Q4_0_TILED;
|
||||
switch (weight_type) {
|
||||
case HTP_TYPE_Q4_0:
|
||||
case HTP_TYPE_IQ4_NL:
|
||||
case HTP_TYPE_Q4_1:
|
||||
case HTP_TYPE_Q8_0:
|
||||
case HTP_TYPE_MXFP4:
|
||||
return (size_t) nb * htp_mm_get_weight_tile_size(weight_type);
|
||||
case HTP_TYPE_F16:
|
||||
return (size_t) k * sizeof(__fp16);
|
||||
case HTP_TYPE_F32:
|
||||
return (size_t) k * sizeof(float);
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Rounds n up to the next multiple of m (m must be non-zero).
static inline size_t htp_mm_round_up(size_t n, size_t m) {
    const size_t n_units = (n + m - 1) / m;
    return n_units * m;
}
|
||||
|
||||
// Pipelined HMX execution is selected only when m exceeds 32.
static inline bool htp_mm_hmx_pipeline(uint32_t m) {
    const bool use_pipeline = (m > 32);
    return use_pipeline;
}
|
||||
|
||||
static inline void htp_mm_hmx_get_2d_chunk_costs(
|
||||
int wtype, uint32_t k, bool pipeline, uint32_t aligned_tile_size,
|
||||
size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
|
||||
) {
|
||||
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
|
||||
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
|
||||
const size_t vec_dot_size = k * sizeof(uint16_t);
|
||||
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
|
||||
const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0;
|
||||
|
||||
*size_per_n_out = (pipeline ? 2 : 1) * (is_quant ? qweight_row_stride : row_stride) +
|
||||
(pipeline ? 2 * vec_dot_size : vec_dot_size);
|
||||
*size_per_m_out = vec_dot_size;
|
||||
*size_per_mn_out = (pipeline ? 2 : 1) * sizeof(uint16_t);
|
||||
}
|
||||
|
||||
// Per-unit VTCM cost coefficients for the batched HMX kernel, written to the
// three out parameters. With row = k * 2 bytes (one 16-bit row of k elements):
//   per n  : 3 * row
//   per m  : group_size * row
//   per mn : 2 bytes
static inline void htp_mm_hmx_get_batched_chunk_costs(
    uint32_t k, uint32_t group_size,
    size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out
) {
    const size_t f16_row = k * sizeof(uint16_t);

    *size_per_mn_out = sizeof(uint16_t);
    *size_per_m_out  = group_size * f16_row;
    *size_per_n_out  = 3 * f16_row;
}
|
||||
|
||||
static inline size_t htp_mm_hmx_get_2d_vtcm_size(
|
||||
int wtype, uint32_t k, size_t mc, size_t nc, bool pipeline, uint32_t act_threads, uint32_t aligned_tile_size
|
||||
) {
|
||||
const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS;
|
||||
const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32);
|
||||
const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k);
|
||||
const size_t vec_dot_size = k * sizeof(uint16_t);
|
||||
|
||||
const size_t act_f32_size = htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE);
|
||||
size_t weight_area_size = is_quant
|
||||
? htp_mm_round_up((nc / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE)
|
||||
: htp_mm_round_up(nc * row_stride, HTP_MM_HMX_TILE_SIZE);
|
||||
if (pipeline) {
|
||||
weight_area_size *= 2;
|
||||
}
|
||||
const size_t act_area_size = htp_mm_round_up(mc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t output_area_size = htp_mm_round_up(mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
|
||||
|
||||
size_t scratch0_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
size_t scratch1_size = pipeline ? scratch0_size : 0;
|
||||
size_t scratch2_size = pipeline ? output_area_size : 0;
|
||||
|
||||
return weight_area_size + act_area_size + act_f32_size + output_area_size +
|
||||
scratch0_size + scratch1_size + scratch2_size + 256;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_hmx_get_batched_vtcm_size(
|
||||
int wtype, uint32_t k, size_t mc, size_t nc, uint32_t group_size, bool use_dma_activation, bool pipeline, uint32_t act_threads) {
|
||||
(void)wtype;
|
||||
(void)pipeline;
|
||||
const size_t vec_dot_size = k * sizeof(uint16_t);
|
||||
const size_t f32_scratch_size = use_dma_activation
|
||||
? htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0;
|
||||
|
||||
const size_t act_head_stride = mc * k;
|
||||
const size_t weight_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t act_area_size = htp_mm_round_up(group_size * act_head_stride * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t output_area_size = htp_mm_round_up(group_size * mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE);
|
||||
const size_t scratch_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE);
|
||||
|
||||
return weight_area_size + act_area_size + output_area_size +
|
||||
2 * scratch_area_size + 256 + f32_scratch_size;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_hvx_get_vtcm_sizes(
|
||||
int kernel_type,
|
||||
int wtype,
|
||||
uint32_t ne10, // k
|
||||
uint32_t src1_nrows, // m_total (or act_nrows)
|
||||
uint32_t n_threads,
|
||||
size_t dst_row_size,
|
||||
size_t src0_row_size,
|
||||
size_t src1_row_size,
|
||||
uint32_t n_prefetch,
|
||||
size_t * vtcm_src0_size_out,
|
||||
size_t * vtcm_src1_size_out,
|
||||
size_t * vtcm_dst_size_out
|
||||
) {
|
||||
size_t vtcm_src0_size = 0;
|
||||
size_t vtcm_src1_size = 0;
|
||||
size_t vtcm_dst_size = 0;
|
||||
|
||||
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
|
||||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
|
||||
wtype == HTP_TYPE_MXFP4);
|
||||
|
||||
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
|
||||
const size_t dst_nrows = (src1_nrows > 1) ? 0 : 1;
|
||||
|
||||
switch (kernel_type) {
|
||||
case HTP_MM_KERNEL_HVX_F16_F16_VTCM: {
|
||||
size_t f16_src1_row_size = htp_mm_round_up(ne10 * 2, 128);
|
||||
vtcm_src1_size = htp_mm_round_up(f16_src1_row_size * src1_nrows, 256);
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_F16_F32_DDR:
|
||||
case HTP_MM_KERNEL_HVX_F16_F16_DDR:
|
||||
case HTP_MM_KERNEL_HVX_F32_F32_DDR:
|
||||
case HTP_MM_KERNEL_HVX_F32_F16_DDR: {
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size, 256) * n_threads;
|
||||
vtcm_src1_size = htp_mm_round_up(n_prefetch * src1_row_size, 256) * n_threads;
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_F32_F32_VTCM: {
|
||||
size_t f32_src1_row_size = htp_mm_round_up(ne10 * 4, 128);
|
||||
vtcm_src1_size = htp_mm_round_up(f32_src1_row_size * src1_nrows, 256);
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads;
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0;
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_QUANT_BLOCK:
|
||||
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
|
||||
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
|
||||
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
||||
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
if (vtcm_src0_size < src1_row_size_padded) {
|
||||
vtcm_src0_size = src1_row_size_padded;
|
||||
}
|
||||
|
||||
vtcm_src0_size = vtcm_src0_size * n_threads;
|
||||
vtcm_dst_size = vtcm_dst_size * n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = ne10 / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
vtcm_src0_size = repacked_vtcm_size * n_threads;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
|
||||
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
|
||||
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
|
||||
|
||||
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
|
||||
if (vtcm_src0_size < src1_row_size_padded) {
|
||||
vtcm_src0_size = src1_row_size_padded;
|
||||
}
|
||||
|
||||
vtcm_src0_size = vtcm_src0_size * n_threads;
|
||||
vtcm_dst_size = vtcm_dst_size * n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = ne10 / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
vtcm_src0_size = repacked_vtcm_size * n_threads;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
*vtcm_src0_size_out = vtcm_src0_size;
|
||||
*vtcm_src1_size_out = vtcm_src1_size;
|
||||
*vtcm_dst_size_out = vtcm_dst_size;
|
||||
|
||||
return vtcm_src0_size + vtcm_src1_size + vtcm_dst_size;
|
||||
}
|
||||
|
||||
static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
|
||||
int wtype,
|
||||
uint32_t ne10, // k
|
||||
uint32_t src1_nrows,
|
||||
uint32_t n_threads,
|
||||
size_t src0_row_size, // nb01
|
||||
uint32_t n_prefetch,
|
||||
size_t * vtcm_src0_size_out,
|
||||
size_t * vtcm_src1_size_out
|
||||
) {
|
||||
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
|
||||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
|
||||
wtype == HTP_TYPE_MXFP4);
|
||||
|
||||
const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128);
|
||||
const size_t src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10)
|
||||
: htp_mm_q8_0_tiled_row_size(ne10);
|
||||
|
||||
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
|
||||
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
if (src0_sz_per_thread < src1_row_size_padded) {
|
||||
src0_sz_per_thread = src1_row_size_padded;
|
||||
}
|
||||
|
||||
if (is_repack) {
|
||||
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
const uint32_t n_k_tiles = ne10 / 32;
|
||||
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
src0_sz_per_thread = repacked_vtcm_size;
|
||||
}
|
||||
|
||||
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
|
||||
|
||||
*vtcm_src0_size_out = vtcm_src0_size;
|
||||
*vtcm_src1_size_out = src1_sz;
|
||||
|
||||
return vtcm_src0_size + src1_sz;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // HTP_MATMUL_OPS_H
|
||||
@@ -14,8 +14,6 @@ Drivers_Dir = 13
|
||||
1 = %DiskId%
|
||||
|
||||
[SourceDisksFiles]
|
||||
libggml-htp-v68.so = 1
|
||||
libggml-htp-v69.so = 1
|
||||
libggml-htp-v73.so = 1
|
||||
libggml-htp-v75.so = 1
|
||||
libggml-htp-v79.so = 1
|
||||
@@ -28,8 +26,6 @@ ExcludeFromSelect = *
|
||||
CopyFiles=Drivers_Dir
|
||||
|
||||
[Drivers_Dir]
|
||||
libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE
|
||||
|
||||
@@ -192,7 +192,10 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mm_f16_f32_kq_kqv
|
||||
conv2d
|
||||
conv2d_f16_f32
|
||||
flash_attn_pre_f16
|
||||
flash_attn_f32_f16
|
||||
flash_attn_f32_q8_0
|
||||
flash_attn_f32_q4_0
|
||||
flash_attn_f16
|
||||
flash_attn_f32
|
||||
)
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
#pragma once
|
||||
|
||||
// Flash-attention per-(dk,dv) tile tuning for the Adreno OpenCL backend.
|
||||
// Isolated from ggml-opencl.cpp so the tuning numbers are easy to find and
|
||||
// edit; the FA dispatch and kernel-compile logic stay in the main file.
|
||||
// This header is a file section — it is #included exactly once, at the point
|
||||
// in ggml-opencl.cpp where the ggml logging macros are already in scope.
|
||||
|
||||
// Per-(dk, dv) FA config; shared by dispatch and supports_op.
// Field order matters: the tuning tables use positional aggregate
// initialisation.
struct ggml_opencl_fa_dim {
    int dk;                   // lookup key, matched together with dv
    int dv;                   // lookup key, matched together with dk
    int bm;                   // tile parameter "bm"
    int bn;                   // tile parameter "bn"
    int n_split;              // split count
    int nkv_split_threshold;  // n_kv threshold for the split variant
};
|
||||
|
||||
// Split variant fires when n_kv >= threshold (threshold=0 -> always split).
|
||||
// Default tuning covers Adreno 7xx/8xx mobile and X1-series laptop GPUs.
|
||||
static const ggml_opencl_fa_dim g_fa_dims_adreno_default[] = {
|
||||
{ 40, 40, 64, 32, 1, 0}, { 64, 64, 64, 32, 2, 64},
|
||||
{ 80, 80, 64, 32, 2, 64}, { 96, 96, 64, 32, 2, 64},
|
||||
{112, 112, 64, 32, 2, 64}, {128, 128, 64, 32, 2, 64},
|
||||
{192, 128, 16, 16, 1, 0},
|
||||
{192, 192, 16, 16, 1, 0},
|
||||
{256, 256, 16, 16, 16, 0},
|
||||
};
|
||||
|
||||
struct ggml_opencl_fa_dim_table {
|
||||
const ggml_opencl_fa_dim * data;
|
||||
size_t count;
|
||||
|
||||
const ggml_opencl_fa_dim * begin() const { return data; }
|
||||
const ggml_opencl_fa_dim * end() const { return data + count; }
|
||||
};
|
||||
|
||||
// Mutable copy of the active table; GGML_OPENCL_FA_TUNE patches entries here
|
||||
// at backend init without touching the const source table.
|
||||
static ggml_opencl_fa_dim g_fa_dims_runtime[
|
||||
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0])];
|
||||
|
||||
static ggml_opencl_fa_dim_table g_opencl_fa_dims = {
|
||||
g_fa_dims_adreno_default,
|
||||
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]),
|
||||
};
|
||||
|
||||
// GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,…] — patches matching entries
|
||||
// in the active table at backend init, before the first FA kernel compiles.
|
||||
// Unmatched (dk,dv) pairs are warned and ignored.
|
||||
static void ggml_opencl_fa_apply_env_overrides() {
|
||||
const char * e = std::getenv("GGML_OPENCL_FA_TUNE");
|
||||
if (!e || !e[0]) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string s = e;
|
||||
size_t pos = 0;
|
||||
while (pos < s.size()) {
|
||||
size_t comma = s.find(',', pos);
|
||||
std::string entry = s.substr(pos, comma == std::string::npos ? std::string::npos : comma - pos);
|
||||
int dk, dv, bm, bn, nsplit, thr;
|
||||
if (std::sscanf(entry.c_str(), "%d:%d:%d:%d:%d:%d", &dk, &dv, &bm, &bn, &nsplit, &thr) == 6) {
|
||||
bool patched = false;
|
||||
for (size_t i = 0; i < g_opencl_fa_dims.count; ++i) {
|
||||
ggml_opencl_fa_dim & d = g_fa_dims_runtime[i];
|
||||
if (d.dk == dk && d.dv == dv) {
|
||||
d.bm = bm; d.bn = bn; d.n_split = nsplit; d.nkv_split_threshold = thr;
|
||||
GGML_LOG_INFO("ggml_opencl: FA tune override DK=%d DV=%d -> bm=%d bn=%d n_split=%d thr=%d\n",
|
||||
dk, dv, bm, bn, nsplit, thr);
|
||||
patched = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!patched) {
|
||||
GGML_LOG_WARN("ggml_opencl: FA tune override DK=%d DV=%d ignored (no matching dim)\n", dk, dv);
|
||||
}
|
||||
} else {
|
||||
GGML_LOG_WARN("ggml_opencl: FA tune override entry malformed: '%s'\n", entry.c_str());
|
||||
}
|
||||
if (comma == std::string::npos) break;
|
||||
pos = comma + 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Copy the default table into the mutable runtime buffer and apply any
|
||||
// GGML_OPENCL_FA_TUNE overrides. A per-generation table can be added here
|
||||
// once it has been tuned on hardware.
|
||||
static void ggml_cl_init_fa_dims_table() {
|
||||
const size_t count = sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
g_fa_dims_runtime[i] = g_fa_dims_adreno_default[i];
|
||||
}
|
||||
g_opencl_fa_dims = { g_fa_dims_runtime, count };
|
||||
ggml_opencl_fa_apply_env_overrides();
|
||||
}
|
||||
+1796
-265
File diff suppressed because it is too large
Load Diff
@@ -1582,6 +1582,158 @@ kernel void kernel_restore_block_q8_0(
|
||||
}
|
||||
}
|
||||
|
||||
// View-aware AoS q8_0 -> f32 dequant (f32/f32 FA path).
//
// One work-item expands one q8_0 block (a 2-byte half scale followed by
// QK8_0 signed 8-bit quants) into QK8_0 floats.
//   global id 0 : block index within a row      (valid range [0, nblk0))
//   global id 1 : row index                     (valid range [0, ne1))
//   global id 2 : flattened batch, i3*ne2 + i2  (i3 valid range [0, ne3))
// Source rows/planes are addressed through the byte strides src_nb1..src_nb3
// starting at src + src_offset; blocks inside a row are read back to back.
// The destination is written densely in (i3, i2, i1, block) order.
kernel void kernel_dequant_q8_0_f32_view_aos(
        global char * src,
        ulong src_offset,
        ulong src_nb1,
        ulong src_nb2,
        ulong src_nb3,
        int nblk0,
        int ne1,
        int ne2,
        int ne3,
        global float * dst
) {
    const int ib    = get_global_id(0);
    const int row   = get_global_id(1);
    const int plane = get_global_id(2);

    // Work-items outside the tensor extent have nothing to do.
    if (ib >= nblk0 || row >= ne1) {
        return;
    }

    const int i3 = plane / ne2;
    const int i2 = plane % ne2;
    if (i3 >= ne3) {
        return;
    }

    // Byte position of this block inside the (possibly gapped) source view.
    const ulong src_byte = src_offset
                         + (ulong)i3  * src_nb3
                         + (ulong)i2  * src_nb2
                         + (ulong)row * src_nb1
                         + (ulong)ib  * (2 + QK8_0);
    const global char * blk    = src + src_byte;
    const float         scale  = vload_half(0, (const global half *)blk);
    const global char * quants = blk + 2;

    // Dense destination: linear block index times QK8_0 elements per block.
    const ulong blk_linear = (((ulong)i3 * ne2 + (ulong)i2) * ne1 + (ulong)row) * nblk0 + ib;
    const ulong dst_first  = blk_linear * QK8_0;

    for (int e = 0; e < QK8_0; ++e) {
        dst[dst_first + e] = scale * (float)quants[e];
    }
}
|
||||
|
||||
// View-aware AoS q8_0 -> f16 dequant. Rows tight, batch strides may be gapped.
//
// One work-item expands one q8_0 block (a 2-byte half scale followed by
// QK8_0 signed 8-bit quants) into QK8_0 halves. The product is formed in
// float and then converted to half.
//   global id 0 : block index within a row      (valid range [0, nblk0))
//   global id 1 : row index                     (valid range [0, ne1))
//   global id 2 : flattened batch, i3*ne2 + i2  (i3 valid range [0, ne3))
// Source rows/planes are addressed through the byte strides src_nb1..src_nb3
// starting at src + src_offset. The destination is written densely in
// (i3, i2, i1, block) order.
kernel void kernel_dequant_q8_0_f16_view_aos(
        global char * src,
        ulong src_offset,
        ulong src_nb1,
        ulong src_nb2,
        ulong src_nb3,
        int nblk0,
        int ne1,
        int ne2,
        int ne3,
        global half * dst
) {
    const int ib    = get_global_id(0);
    const int row   = get_global_id(1);
    const int plane = get_global_id(2);

    // Work-items outside the tensor extent have nothing to do.
    if (ib >= nblk0 || row >= ne1) {
        return;
    }

    const int i3 = plane / ne2;
    const int i2 = plane % ne2;
    if (i3 >= ne3) {
        return;
    }

    // Byte position of this block inside the (possibly gapped) source view.
    const ulong src_byte = src_offset
                         + (ulong)i3  * src_nb3
                         + (ulong)i2  * src_nb2
                         + (ulong)row * src_nb1
                         + (ulong)ib  * (2 + QK8_0);
    const global char * blk    = src + src_byte;
    const float         scale  = vload_half(0, (const global half *)blk);
    const global char * quants = blk + 2;

    // Dense destination: linear block index times QK8_0 elements per block.
    const ulong blk_linear = (((ulong)i3 * ne2 + (ulong)i2) * ne1 + (ulong)row) * nblk0 + ib;
    const ulong dst_first  = blk_linear * QK8_0;

    for (int e = 0; e < QK8_0; ++e) {
        dst[dst_first + e] = (half)(scale * (float)quants[e]);
    }
}
|
||||
|
||||
// View-aware AoS q4_0 -> f32 dequant (mirrors the q8_0 view variant).
//
// One work-item expands one q4_0 block (a 2-byte half scale followed by
// QK4_0/2 bytes, two 4-bit quants per byte) into QK4_0 floats. For byte k the
// low nibble becomes element k and the high nibble element k + QK4_0/2; both
// are re-centred by subtracting 8 before scaling.
//   global id 0 : block index within a row      (valid range [0, nblk0))
//   global id 1 : row index                     (valid range [0, ne1))
//   global id 2 : flattened batch, i3*ne2 + i2  (i3 valid range [0, ne3))
// Source rows/planes are addressed through the byte strides src_nb1..src_nb3
// starting at src + src_offset. The destination is written densely in
// (i3, i2, i1, block) order.
kernel void kernel_dequant_q4_0_f32_view_aos(
        global char * src,
        ulong src_offset,
        ulong src_nb1,
        ulong src_nb2,
        ulong src_nb3,
        int nblk0,
        int ne1,
        int ne2,
        int ne3,
        global float * dst
) {
    const int ib    = get_global_id(0);
    const int row   = get_global_id(1);
    const int plane = get_global_id(2);

    // Work-items outside the tensor extent have nothing to do.
    if (ib >= nblk0 || row >= ne1) {
        return;
    }

    const int i3 = plane / ne2;
    const int i2 = plane % ne2;
    if (i3 >= ne3) {
        return;
    }

    // Byte position of this block inside the (possibly gapped) source view.
    const ulong src_byte = src_offset
                         + (ulong)i3  * src_nb3
                         + (ulong)i2  * src_nb2
                         + (ulong)row * src_nb1
                         + (ulong)ib  * (2 + QK4_0/2);
    const global char *  blk    = src + src_byte;
    const float          scale  = vload_half(0, (const global half *)blk);
    const global uchar * packed = (const global uchar *)(blk + 2);

    // Dense destination: linear block index times QK4_0 elements per block.
    const ulong blk_linear = (((ulong)i3 * ne2 + (ulong)i2) * ne1 + (ulong)row) * nblk0 + ib;
    const ulong dst_first  = blk_linear * QK4_0;

    for (int k = 0; k < QK4_0/2; ++k) {
        const uchar pair = packed[k];
        const int   lo   = (int)(pair & 0x0F) - 8;
        const int   hi   = (int)(pair >> 4) - 8;
        dst[dst_first + k]           = scale * (float)lo;
        dst[dst_first + k + QK4_0/2] = scale * (float)hi;
    }
}
|
||||
|
||||
// View-aware AoS q4_0 -> f16 dequant (mirrors the q8_0 view variant).
//
// One work-item expands one q4_0 block (a 2-byte half scale followed by
// QK4_0/2 bytes, two 4-bit quants per byte) into QK4_0 halves. For byte k the
// low nibble becomes element k and the high nibble element k + QK4_0/2; both
// are re-centred by subtracting 8, scaled in float, then converted to half.
//   global id 0 : block index within a row      (valid range [0, nblk0))
//   global id 1 : row index                     (valid range [0, ne1))
//   global id 2 : flattened batch, i3*ne2 + i2  (i3 valid range [0, ne3))
// Source rows/planes are addressed through the byte strides src_nb1..src_nb3
// starting at src + src_offset. The destination is written densely in
// (i3, i2, i1, block) order.
kernel void kernel_dequant_q4_0_f16_view_aos(
        global char * src,
        ulong src_offset,
        ulong src_nb1,
        ulong src_nb2,
        ulong src_nb3,
        int nblk0,
        int ne1,
        int ne2,
        int ne3,
        global half * dst
) {
    const int ib    = get_global_id(0);
    const int row   = get_global_id(1);
    const int plane = get_global_id(2);

    // Work-items outside the tensor extent have nothing to do.
    if (ib >= nblk0 || row >= ne1) {
        return;
    }

    const int i3 = plane / ne2;
    const int i2 = plane % ne2;
    if (i3 >= ne3) {
        return;
    }

    // Byte position of this block inside the (possibly gapped) source view.
    const ulong src_byte = src_offset
                         + (ulong)i3  * src_nb3
                         + (ulong)i2  * src_nb2
                         + (ulong)row * src_nb1
                         + (ulong)ib  * (2 + QK4_0/2);
    const global char *  blk    = src + src_byte;
    const float          scale  = vload_half(0, (const global half *)blk);
    const global uchar * packed = (const global uchar *)(blk + 2);

    // Dense destination: linear block index times QK4_0 elements per block.
    const ulong blk_linear = (((ulong)i3 * ne2 + (ulong)i2) * ne1 + (ulong)row) * nblk0 + ib;
    const ulong dst_first  = blk_linear * QK4_0;

    for (int k = 0; k < QK4_0/2; ++k) {
        const uchar pair = packed[k];
        const int   lo   = (int)(pair & 0x0F) - 8;
        const int   hi   = (int)(pair >> 4) - 8;
        dst[dst_first + k]           = (half)(scale * (float)lo);
        dst[dst_first + k + QK4_0/2] = (half)(scale * (float)hi);
    }
}
|
||||
|
||||
kernel void kernel_restore_block_q8_0_trans(
|
||||
global uchar * src_q,
|
||||
global half * src_d,
|
||||
|
||||
@@ -4,14 +4,26 @@
|
||||
#define ACC_TYPE4 float4
|
||||
#define DATA_TYPE half
|
||||
#define DATA_TYPE4 half4
|
||||
#define CONVERT_ACC4(x) convert_float4(x)
|
||||
#define CONVERT_DATA4(x) convert_half4(x)
|
||||
#define CONVERT_ACC4(x) ((float4)((float)(x).s0, (float)(x).s1, (float)(x).s2, (float)(x).s3))
|
||||
#define CONVERT_DATA4(x) ((half4)((half)(x).s0, (half)(x).s1, (half)(x).s2, (half)(x).s3))
|
||||
|
||||
#define DK_VEC (DK/4)
|
||||
#define DV_VEC (DV/4)
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#define Q1_WG_SIZE 64
|
||||
|
||||
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
|
||||
// an infinite operand can cause undefined behavior and miscompilation for exp.
|
||||
// Therefore, a large negative value is used instead.
|
||||
#define FA_M_INIT (-3.0e38f)
|
||||
|
||||
// Drop full unroll at DK>=192 — Adreno compiler host-memory budget.
|
||||
#if DK >= 192
|
||||
#define FA_UNROLL
|
||||
#else
|
||||
#define FA_UNROLL _Pragma("unroll")
|
||||
#endif
|
||||
|
||||
inline float get_alibi_slope(
|
||||
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
||||
) {
|
||||
@@ -81,18 +93,18 @@ __kernel void flash_attn_f16(
|
||||
if (my_query_row < n_q) {
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
ACC_TYPE m_i = FA_M_INIT;
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
@@ -125,49 +137,72 @@ __kernel void flash_attn_f16(
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
for (int j = 0; j < BLOCK_N; j += 4) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
const int k_row2 = k_start + j + 2;
|
||||
const int k_row3 = k_start + j + 3;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
|
||||
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
|
||||
}
|
||||
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
|
||||
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
||||
const int causal_limit = n_kv - n_q + my_query_row;
|
||||
if (k_row0 > causal_limit) s0 = FA_M_INIT;
|
||||
if (k_row1 > causal_limit) s1 = FA_M_INIT;
|
||||
if (k_row2 > causal_limit) s2 = FA_M_INIT;
|
||||
if (k_row3 > causal_limit) s3 = FA_M_INIT;
|
||||
}
|
||||
|
||||
if (k_row0 >= n_kv) score0 = -INFINITY;
|
||||
if (k_row1 >= n_kv) score1 = -INFINITY;
|
||||
if (k_row0 >= n_kv) s0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) s1 = FA_M_INIT;
|
||||
if (k_row2 >= n_kv) s2 = FA_M_INIT;
|
||||
if (k_row3 >= n_kv) s3 = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL) {
|
||||
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
|
||||
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
s0 = logit_softcap * tanh(s0 / logit_softcap);
|
||||
s1 = logit_softcap * tanh(s1 / logit_softcap);
|
||||
s2 = logit_softcap * tanh(s2 / logit_softcap);
|
||||
s3 = logit_softcap * tanh(s3 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
const ACC_TYPE p0 = exp(score0 - m_new);
|
||||
const ACC_TYPE p1 = exp(score1 - m_new);
|
||||
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
||||
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
|
||||
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
|
||||
const ACC_TYPE p0 = native_exp(s0 - m_new);
|
||||
const ACC_TYPE p1 = native_exp(s1 - m_new);
|
||||
const ACC_TYPE p2 = native_exp(s2 - m_new);
|
||||
const ACC_TYPE p3 = native_exp(s3 - m_new);
|
||||
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
|
||||
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
|
||||
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
|
||||
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
|
||||
mad(p0, CONVERT_ACC4(l_v[j][i]),
|
||||
o_acc[i] * scale_prev))));
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1;
|
||||
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
@@ -179,7 +214,7 @@ __kernel void flash_attn_f16(
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
@@ -191,12 +226,12 @@ __kernel void flash_attn_f16(
|
||||
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_i;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
@@ -258,7 +293,7 @@ __kernel void flash_attn_f16_q1(
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
@@ -270,12 +305,12 @@ __kernel void flash_attn_f16_q1(
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -293,7 +328,7 @@ __kernel void flash_attn_f16_q1(
|
||||
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
||||
local_m[tid] = m_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -301,7 +336,7 @@ __kernel void flash_attn_f16_q1(
|
||||
const ACC_TYPE m_final = local_m[0];
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
@@ -311,7 +346,7 @@ __kernel void flash_attn_f16_q1(
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -325,7 +360,7 @@ __kernel void flash_attn_f16_q1(
|
||||
}
|
||||
const ACC_TYPE p = exp(score - m_final);
|
||||
l_i += p;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
|
||||
}
|
||||
@@ -335,7 +370,7 @@ __kernel void flash_attn_f16_q1(
|
||||
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
||||
local_l[tid] = l_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_l[tid] += local_l[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -354,7 +389,7 @@ __kernel void flash_attn_f16_q1(
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
local_o_comp[tid] = o_acc[i];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -364,7 +399,7 @@ __kernel void flash_attn_f16_q1(
|
||||
}
|
||||
}
|
||||
} else if (tid == 0) {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,18 @@
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#define Q1_WG_SIZE 64
|
||||
|
||||
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
|
||||
// an infinite operand can cause undefined behavior and miscompilation for exp.
|
||||
// Therefore, a large negative value is used instead.
|
||||
#define FA_M_INIT (-3.0e38f)
|
||||
|
||||
// Drop full unroll at DK>=192 — Adreno compiler host-memory budget.
|
||||
#if DK >= 192
|
||||
#define FA_UNROLL
|
||||
#else
|
||||
#define FA_UNROLL _Pragma("unroll")
|
||||
#endif
|
||||
|
||||
inline float get_alibi_slope(
|
||||
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
||||
) {
|
||||
@@ -82,18 +94,18 @@ __kernel void flash_attn_f32(
|
||||
if (my_query_row < n_q) {
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
ACC_TYPE m_i = FA_M_INIT;
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
@@ -126,49 +138,72 @@ __kernel void flash_attn_f32(
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
for (int j = 0; j < BLOCK_N; j += 4) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
const int k_row2 = k_start + j + 2;
|
||||
const int k_row3 = k_start + j + 3;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
|
||||
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
|
||||
}
|
||||
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
|
||||
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
||||
const int causal_limit = n_kv - n_q + my_query_row;
|
||||
if (k_row0 > causal_limit) s0 = FA_M_INIT;
|
||||
if (k_row1 > causal_limit) s1 = FA_M_INIT;
|
||||
if (k_row2 > causal_limit) s2 = FA_M_INIT;
|
||||
if (k_row3 > causal_limit) s3 = FA_M_INIT;
|
||||
}
|
||||
|
||||
if (k_row0 >= n_kv) score0 = -INFINITY;
|
||||
if (k_row1 >= n_kv) score1 = -INFINITY;
|
||||
if (k_row0 >= n_kv) s0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) s1 = FA_M_INIT;
|
||||
if (k_row2 >= n_kv) s2 = FA_M_INIT;
|
||||
if (k_row3 >= n_kv) s3 = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
|
||||
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
s0 = logit_softcap * tanh(s0 / logit_softcap);
|
||||
s1 = logit_softcap * tanh(s1 / logit_softcap);
|
||||
s2 = logit_softcap * tanh(s2 / logit_softcap);
|
||||
s3 = logit_softcap * tanh(s3 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
const ACC_TYPE p0 = exp(score0 - m_new);
|
||||
const ACC_TYPE p1 = exp(score1 - m_new);
|
||||
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
||||
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
|
||||
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
|
||||
const ACC_TYPE p0 = native_exp(s0 - m_new);
|
||||
const ACC_TYPE p1 = native_exp(s1 - m_new);
|
||||
const ACC_TYPE p2 = native_exp(s2 - m_new);
|
||||
const ACC_TYPE p3 = native_exp(s3 - m_new);
|
||||
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
|
||||
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
|
||||
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
|
||||
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
|
||||
mad(p0, CONVERT_ACC4(l_v[j][i]),
|
||||
o_acc[i] * scale_prev))));
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1;
|
||||
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
@@ -180,7 +215,7 @@ __kernel void flash_attn_f32(
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
@@ -192,12 +227,12 @@ __kernel void flash_attn_f32(
|
||||
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_i;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
@@ -259,7 +294,7 @@ __kernel void flash_attn_f32_q1(
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
@@ -271,12 +306,12 @@ __kernel void flash_attn_f32_q1(
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -294,7 +329,7 @@ __kernel void flash_attn_f32_q1(
|
||||
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
||||
local_m[tid] = m_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -302,7 +337,7 @@ __kernel void flash_attn_f32_q1(
|
||||
const ACC_TYPE m_final = local_m[0];
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
@@ -312,7 +347,7 @@ __kernel void flash_attn_f32_q1(
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -326,7 +361,7 @@ __kernel void flash_attn_f32_q1(
|
||||
}
|
||||
const ACC_TYPE p = exp(score - m_final);
|
||||
l_i += p;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
|
||||
}
|
||||
@@ -336,7 +371,7 @@ __kernel void flash_attn_f32_q1(
|
||||
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
||||
local_l[tid] = l_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_l[tid] += local_l[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -355,7 +390,7 @@ __kernel void flash_attn_f32_q1(
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
local_o_comp[tid] = o_acc[i];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -365,7 +400,7 @@ __kernel void flash_attn_f32_q1(
|
||||
}
|
||||
}
|
||||
} else if (tid == 0) {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_khr_subgroup_shuffle
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable
|
||||
#define HAS_SUBGROUP_SHUFFLE 1
|
||||
#elif defined(cl_qcom_subgroup_shuffle)
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
|
||||
#define HAS_SUBGROUP_SHUFFLE 1
|
||||
#endif
|
||||
|
||||
#define ACC_TYPE float
|
||||
#define ACC_TYPE4 float4
|
||||
#define Q_DATA_TYPE4 float4
|
||||
@@ -12,9 +20,34 @@
|
||||
|
||||
#define DK_VEC (DK/4)
|
||||
#define DV_VEC (DV/4)
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#define Q1_WG_SIZE 64
|
||||
|
||||
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
|
||||
// an infinite operand can cause undefined behavior and miscompilation for exp.
|
||||
// Therefore, a large negative value is used instead.
|
||||
#define FA_M_INIT (-3.0e38f)
|
||||
|
||||
// Drop full unroll at DK>=192 — Adreno compiler host-memory budget.
|
||||
#if DK >= 192
|
||||
#define FA_UNROLL
|
||||
#else
|
||||
#define FA_UNROLL _Pragma("unroll")
|
||||
#endif
|
||||
|
||||
// N_SPLIT>1 splits DK/DV across threads to cut per-thread register use.
|
||||
#ifndef N_SPLIT
|
||||
#define N_SPLIT 1
|
||||
#endif
|
||||
|
||||
#define SPLIT_DK_VEC (DK_VEC / N_SPLIT)
|
||||
#define SPLIT_DV_VEC (DV_VEC / N_SPLIT)
|
||||
|
||||
#if N_SPLIT > 1
|
||||
#define WG_SIZE (BLOCK_M * N_SPLIT)
|
||||
#else
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#endif
|
||||
|
||||
inline float get_alibi_slope(
|
||||
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
||||
) {
|
||||
@@ -54,19 +87,38 @@ __kernel void flash_attn_f32_f16(
|
||||
const int mask_ne2,
|
||||
const int mask_ne3,
|
||||
const global void* sinks_void,
|
||||
const ulong sinks_offset
|
||||
const ulong sinks_offset,
|
||||
const global void * k_pad_void,
|
||||
const global void * v_pad_void,
|
||||
const global void * mask_pad_void,
|
||||
const global char * blk,
|
||||
const int n_kv_blocks,
|
||||
const ulong mask_pad_nb1,
|
||||
const ulong mask_pad_nb2,
|
||||
const ulong mask_pad_nb3
|
||||
) {
|
||||
const int tid = get_local_id(0);
|
||||
const int block_q_idx = get_group_id(0);
|
||||
const int head_batch_idx = get_global_id(1);
|
||||
|
||||
const int my_query_row = block_q_idx * BLOCK_M + tid;
|
||||
#if N_SPLIT > 1
|
||||
const int q_lane = tid / N_SPLIT;
|
||||
const int split_idx = tid % N_SPLIT;
|
||||
#else
|
||||
const int q_lane = tid;
|
||||
const int split_idx = 0;
|
||||
#endif
|
||||
|
||||
const int my_query_row = block_q_idx * BLOCK_M + q_lane;
|
||||
const int query_valid = my_query_row < n_q;
|
||||
|
||||
const int batch_idx = head_batch_idx / n_head;
|
||||
const int head_idx = head_batch_idx % n_head;
|
||||
|
||||
const int gqa_ratio = n_head / n_head_kv;
|
||||
const int head_kv_idx = head_idx / gqa_ratio;
|
||||
const int mask_head_idx = mask_void != NULL ? head_idx % mask_ne2 : 0;
|
||||
const int mask_batch_idx = mask_void != NULL ? batch_idx % mask_ne3 : 0;
|
||||
|
||||
const global char* q_base = (const global char*)q_void + q_offset;
|
||||
const global char* k_base = (const global char*)k_void + k_offset;
|
||||
@@ -75,27 +127,41 @@ __kernel void flash_attn_f32_f16(
|
||||
|
||||
const global char* mask_base = NULL;
|
||||
if (mask_void != NULL) {
|
||||
const int mask_head_idx = head_idx % mask_ne2;
|
||||
const int mask_batch_idx = batch_idx % mask_ne3;
|
||||
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
||||
}
|
||||
const global char* mask_pad_base = NULL;
|
||||
if (mask_pad_void != NULL) {
|
||||
mask_pad_base = (const global char*)mask_pad_void + mask_batch_idx * mask_pad_nb3 + mask_head_idx * mask_pad_nb2;
|
||||
}
|
||||
const global char* blk_base = NULL;
|
||||
if (blk != NULL) {
|
||||
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
|
||||
blk_base = blk + (((mask_batch_idx * mask_ne2) + mask_head_idx) * n_q_blocks + block_q_idx) * n_kv_blocks;
|
||||
}
|
||||
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
if (my_query_row < n_q) {
|
||||
ACC_TYPE4 q_priv[SPLIT_DK_VEC];
|
||||
const int dk_off = split_idx * SPLIT_DK_VEC;
|
||||
if (query_valid) {
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
||||
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_Q_ACC4(q_ptr[dk_off + i]);
|
||||
}
|
||||
} else {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
|
||||
q_priv[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
ACC_TYPE4 o_acc[SPLIT_DV_VEC];
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
|
||||
ACC_TYPE m_i = FA_M_INIT;
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
@@ -103,86 +169,369 @@ __kernel void flash_attn_f32_f16(
|
||||
__local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
|
||||
__local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
|
||||
|
||||
#if N_SPLIT > 1 && !defined(HAS_SUBGROUP_SHUFFLE)
|
||||
__local ACC_TYPE local_partial[BLOCK_N][WG_SIZE];
|
||||
__local ACC_TYPE local_p[BLOCK_M][BLOCK_N];
|
||||
__local ACC_TYPE local_softmax_scale[BLOCK_M];
|
||||
__local ACC_TYPE local_l_inv[BLOCK_M];
|
||||
#endif
|
||||
|
||||
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
|
||||
char blk_cur = 1;
|
||||
if (blk_base != NULL) {
|
||||
blk_cur = blk_base[k_start / BLOCK_N];
|
||||
if (blk_cur == 0) continue;
|
||||
}
|
||||
|
||||
const int use_kv_pad = k_pad_void != NULL && k_start + BLOCK_N > n_kv;
|
||||
const int k_tile_start = use_kv_pad ? 0 : k_start;
|
||||
const ulong k_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * k_nb1 : k_nb2;
|
||||
const ulong k_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * k_tile_nb2 : k_nb3;
|
||||
const ulong v_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * v_nb1 : v_nb2;
|
||||
const ulong v_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * v_tile_nb2 : v_nb3;
|
||||
const global char* k_tile_base = use_kv_pad ? (const global char*) k_pad_void : k_base;
|
||||
const global char* v_tile_base = use_kv_pad ? (const global char*) v_pad_void : v_base;
|
||||
|
||||
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
|
||||
const int row = i / DK_VEC;
|
||||
const int col = i % DK_VEC;
|
||||
const int k_row_idx = k_start + row;
|
||||
if (k_row_idx < n_kv) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
|
||||
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
|
||||
const int k_row_idx = k_tile_start + row;
|
||||
if (use_kv_pad || k_row_idx < n_kv) {
|
||||
const ulong k_row_offset = batch_idx * k_tile_nb3 + head_kv_idx * k_tile_nb2 + k_row_idx * k_nb1;
|
||||
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_tile_base + k_row_offset))[col];
|
||||
} else {
|
||||
l_k[row][col] = (KV_DATA_TYPE4)(0.0h);
|
||||
}
|
||||
}
|
||||
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
|
||||
const int row = i / DV_VEC;
|
||||
const int col = i % DV_VEC;
|
||||
const int v_row_idx = k_start + row;
|
||||
if (v_row_idx < n_kv) {
|
||||
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
|
||||
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
|
||||
const int v_row_idx = k_tile_start + row;
|
||||
if (use_kv_pad || v_row_idx < n_kv) {
|
||||
const ulong v_row_offset = batch_idx * v_tile_nb3 + head_kv_idx * v_tile_nb2 + v_row_idx * v_nb1;
|
||||
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_tile_base + v_row_offset))[col];
|
||||
} else {
|
||||
l_v[row][col] = (KV_DATA_TYPE4)(0.0h);
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (my_query_row >= n_q) {
|
||||
continue;
|
||||
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
|
||||
{
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
|
||||
ACC_TYPE partial0 = 0.0f;
|
||||
ACC_TYPE partial1 = 0.0f;
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < SPLIT_DK_VEC; k++) {
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
ACC_TYPE4 dot0 = qk * CONVERT_KV_ACC4(l_k[j ][dk_off + k]);
|
||||
ACC_TYPE4 dot1 = qk * CONVERT_KV_ACC4(l_k[j+1][dk_off + k]);
|
||||
partial0 += dot0.s0 + dot0.s1 + dot0.s2 + dot0.s3;
|
||||
partial1 += dot1.s0 + dot1.s1 + dot1.s2 + dot1.s3;
|
||||
}
|
||||
|
||||
FA_UNROLL
|
||||
for (int step = 1; step < N_SPLIT; step <<= 1) {
|
||||
partial0 += sub_group_shuffle_xor(partial0, step);
|
||||
partial1 += sub_group_shuffle_xor(partial1, step);
|
||||
}
|
||||
|
||||
ACC_TYPE score0 = partial0 * scale;
|
||||
ACC_TYPE score1 = partial1 * scale;
|
||||
|
||||
if (!query_valid) { score0 = FA_M_INIT; score1 = FA_M_INIT; }
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = FA_M_INIT;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = FA_M_INIT;
|
||||
}
|
||||
if (k_row0 >= n_kv) score0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) score1 = FA_M_INIT;
|
||||
|
||||
if (query_valid && mask_base != NULL && blk_cur != 2) {
|
||||
if (use_kv_pad && mask_pad_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
|
||||
score0 += slope * (ACC_TYPE)mask_ptr[j];
|
||||
score1 += slope * (ACC_TYPE)mask_ptr[j + 1];
|
||||
} else {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
|
||||
// far negative so the tile contributes 0, not exp(0)=1.
|
||||
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
|
||||
const ACC_TYPE sp = native_exp(m_i - m_exp);
|
||||
const ACC_TYPE p0 = native_exp(score0 - m_exp);
|
||||
const ACC_TYPE p1 = native_exp(score1 - m_exp);
|
||||
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * sp
|
||||
+ p0 * CONVERT_KV_ACC4(l_v[j ][dv_off + i])
|
||||
+ p1 * CONVERT_KV_ACC4(l_v[j+1][dv_off + i]);
|
||||
}
|
||||
l_i = l_i * sp + p0 + p1;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
#elif N_SPLIT > 1
|
||||
// N_SPLIT>1 fallback (no shuffle): 3-phase local-memory reduction.
|
||||
// Phase 1 — partial dots for all BLOCK_N tokens.
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < SPLIT_DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][dk_off + k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
||||
}
|
||||
|
||||
if (k_row0 >= n_kv) score0 = -INFINITY;
|
||||
if (k_row1 >= n_kv) score1 = -INFINITY;
|
||||
|
||||
if (mask_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
const ACC_TYPE p0 = exp(score0 - m_new);
|
||||
const ACC_TYPE p1 = exp(score1 - m_new);
|
||||
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1;
|
||||
m_i = m_new;
|
||||
local_partial[j][tid] =
|
||||
dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3;
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE); // 1 barrier: partial dots visible
|
||||
|
||||
// Phase 2 — split_idx==0 reduces partial sums and computes block softmax.
|
||||
if (split_idx == 0) {
|
||||
if (query_valid) {
|
||||
ACC_TYPE m_new = m_i;
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
const int k_row = k_start + j;
|
||||
ACC_TYPE score = 0.0f;
|
||||
FA_UNROLL
|
||||
for (int s = 0; s < N_SPLIT; s++) {
|
||||
score += local_partial[j][q_lane * N_SPLIT + s];
|
||||
}
|
||||
score *= scale;
|
||||
|
||||
if (is_causal && k_row > (n_kv - n_q + my_query_row)) score = FA_M_INIT;
|
||||
if (k_row >= n_kv) score = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL && blk_cur != 2) {
|
||||
if (use_kv_pad && mask_pad_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
|
||||
score += slope * (ACC_TYPE)mask_ptr[j];
|
||||
} else {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row < n_kv) score += slope * (ACC_TYPE)mask_ptr[k_row];
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score = logit_softcap * tanh(score / logit_softcap);
|
||||
}
|
||||
|
||||
m_new = max(m_new, score);
|
||||
local_p[q_lane][j] = score;
|
||||
}
|
||||
|
||||
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
|
||||
const ACC_TYPE sp = native_exp(m_i - m_exp);
|
||||
ACC_TYPE l_new = l_i * sp;
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
const ACC_TYPE p = native_exp(local_p[q_lane][j] - m_exp);
|
||||
local_p[q_lane][j] = p;
|
||||
l_new += p;
|
||||
}
|
||||
local_softmax_scale[q_lane] = sp;
|
||||
l_i = l_new;
|
||||
m_i = m_new;
|
||||
} else {
|
||||
local_softmax_scale[q_lane] = 1.0f;
|
||||
for (int j = 0; j < BLOCK_N; ++j) local_p[q_lane][j] = 0.0f;
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Phase 3 — V accumulate using broadcast probabilities.
|
||||
{
|
||||
const ACC_TYPE sp_block = local_softmax_scale[q_lane];
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] *= sp_block;
|
||||
}
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
const ACC_TYPE p = local_p[q_lane][j];
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] = mad(p, CONVERT_KV_ACC4(l_v[j][dv_off + i]), o_acc[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
// N_SPLIT==1: j+=4 unroll. Requires BLOCK_N % 4 == 0.
|
||||
if (query_valid) {
|
||||
for (int j = 0; j < BLOCK_N; j += 4) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
const int k_row2 = k_start + j + 2;
|
||||
const int k_row3 = k_start + j + 3;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
dot_acc0 = mad(qk, CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(qk, CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
dot_acc2 = mad(qk, CONVERT_KV_ACC4(l_k[j+2][k]), dot_acc2);
|
||||
dot_acc3 = mad(qk, CONVERT_KV_ACC4(l_k[j+3][k]), dot_acc3);
|
||||
}
|
||||
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
|
||||
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
const int causal_limit = n_kv - n_q + my_query_row;
|
||||
if (k_row0 > causal_limit) s0 = FA_M_INIT;
|
||||
if (k_row1 > causal_limit) s1 = FA_M_INIT;
|
||||
if (k_row2 > causal_limit) s2 = FA_M_INIT;
|
||||
if (k_row3 > causal_limit) s3 = FA_M_INIT;
|
||||
}
|
||||
if (k_row0 >= n_kv) s0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) s1 = FA_M_INIT;
|
||||
if (k_row2 >= n_kv) s2 = FA_M_INIT;
|
||||
if (k_row3 >= n_kv) s3 = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL && blk_cur != 2) {
|
||||
if (use_kv_pad && mask_pad_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
|
||||
s0 += slope * (ACC_TYPE)mask_ptr[j];
|
||||
s1 += slope * (ACC_TYPE)mask_ptr[j + 1];
|
||||
s2 += slope * (ACC_TYPE)mask_ptr[j + 2];
|
||||
s3 += slope * (ACC_TYPE)mask_ptr[j + 3];
|
||||
} else {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
|
||||
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
s0 = logit_softcap * tanh(s0 / logit_softcap);
|
||||
s1 = logit_softcap * tanh(s1 / logit_softcap);
|
||||
s2 = logit_softcap * tanh(s2 / logit_softcap);
|
||||
s3 = logit_softcap * tanh(s3 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
|
||||
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
|
||||
// far negative so the tile contributes 0, not exp(0)=1.
|
||||
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
|
||||
const ACC_TYPE scale_prev = native_exp(m_i - m_exp);
|
||||
const ACC_TYPE p0 = native_exp(s0 - m_exp);
|
||||
const ACC_TYPE p1 = native_exp(s1 - m_exp);
|
||||
const ACC_TYPE p2 = native_exp(s2 - m_exp);
|
||||
const ACC_TYPE p3 = native_exp(s3 - m_exp);
|
||||
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = mad(p3, CONVERT_KV_ACC4(l_v[j+3][i]),
|
||||
mad(p2, CONVERT_KV_ACC4(l_v[j+2][i]),
|
||||
mad(p1, CONVERT_KV_ACC4(l_v[j+1][i]),
|
||||
mad(p0, CONVERT_KV_ACC4(l_v[j][i]),
|
||||
o_acc[i] * scale_prev))));
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// End of tile: every thread must finish reading l_k/l_v before the
|
||||
// next iteration's load overwrites them (WAR hazard on local memory).
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
if (my_query_row < n_q) {
|
||||
// Write output.
|
||||
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
|
||||
if (query_valid) {
|
||||
ACC_TYPE sinks_sp = 1.0f;
|
||||
if (sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
sinks_sp = exp(m_i - m_final);
|
||||
l_i = l_i * sinks_sp + exp(m_sink - m_final);
|
||||
m_i = m_final;
|
||||
}
|
||||
const ACC_TYPE l_inv = (l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
||||
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_inv > 0.0f) {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
|
||||
}
|
||||
} else {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif N_SPLIT > 1
|
||||
if (split_idx == 0) {
|
||||
ACC_TYPE sinks_sp = 1.0f;
|
||||
if (query_valid && sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
sinks_sp = exp(m_i - m_final);
|
||||
l_i = l_i * sinks_sp + exp(m_sink - m_final);
|
||||
m_i = m_final;
|
||||
}
|
||||
local_softmax_scale[q_lane] = sinks_sp;
|
||||
local_l_inv[q_lane] = (query_valid && l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (query_valid) {
|
||||
const ACC_TYPE sinks_sp = local_softmax_scale[q_lane];
|
||||
const ACC_TYPE l_inv = local_l_inv[q_lane];
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
||||
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_inv > 0.0f) {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
|
||||
}
|
||||
} else {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (query_valid) {
|
||||
if (sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
@@ -194,17 +543,18 @@ __kernel void flash_attn_f32_f16(
|
||||
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_i;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__kernel void flash_attn_f32_f16_q1(
|
||||
@@ -258,13 +608,16 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
||||
}
|
||||
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
// Q is uniform across WG threads (n_q=1). Share via local memory to
|
||||
// avoid per-thread q_priv[DK_VEC] dynamic-indexed private array that
|
||||
// spills to DDR on Adreno.
|
||||
__local ACC_TYPE4 q_shared[DK_VEC];
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
||||
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
||||
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
|
||||
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
|
||||
@@ -273,14 +626,14 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||
if (mask_base != NULL) {
|
||||
@@ -296,7 +649,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
||||
local_m[tid] = m_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -304,7 +657,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
const ACC_TYPE m_final = local_m[0];
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
@@ -314,9 +667,9 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
|
||||
const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||
if (mask_base != NULL) {
|
||||
@@ -328,7 +681,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
}
|
||||
const ACC_TYPE p = exp(score - m_final);
|
||||
l_i += p;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
|
||||
}
|
||||
@@ -338,7 +691,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
||||
local_l[tid] = l_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_l[tid] += local_l[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -357,7 +710,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
local_o_comp[tid] = o_acc[i];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -367,7 +720,257 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
}
|
||||
}
|
||||
} else if (tid == 0) {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
// Flash-decoding split pass. gid(2) = q_idx * n_splits + split_idx.
|
||||
// Partial record per split: [m, l, O[DV]]. Merge kernel applies sink + norm.
|
||||
#define FA_PARTIAL_FLOATS (2 + DV)
|
||||
|
||||
// Flash-decoding pass 1: one work-group handles one (batch, head, query row,
// KV split) and writes a partial softmax record of FA_PARTIAL_FLOATS floats:
//   rec[0]      = m  (maximum score seen inside this split)
//   rec[1]      = l  (sum of exp(score - m) over this split)
//   rec[2 ...]  = O  (sum of exp(score - m) * V, DV_VEC float4 groups)
// The record is not normalised here; rec[0]/rec[1] are kept so a later pass
// can rescale and combine the splits.
//
// Work-item layout as read by this kernel:
//   get_local_id(0)  -> tid, strides over the KV rows of the split
//   get_global_id(1) -> batch_idx * n_head + head_idx
//   get_global_id(2) -> q_idx * n_splits + split_idx
__kernel void flash_attn_f32_f16_q1_split(
    const global void * q_void, ulong q_offset,
    const global void * k_void, ulong k_offset,
    const global void * v_void, ulong v_offset,
    const float scale,
    const int n_q,
    const int n_kv,
    const int n_head,
    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
    const float max_bias,
    const float m0,
    const float m1,
    const int n_head_log2,
    const float logit_softcap,
    const int n_head_kv,
    const global void * mask_void,
    const ulong mask_offset,
    const ulong mask_nb1,
    const ulong mask_nb2,
    const ulong mask_nb3,
    const int mask_ne2,
    const int mask_ne3,
    global float * partial_void,
    const int n_splits,
    const int kv_per_split
) {
    const int tid = get_local_id(0);
    const int head_batch_idx = get_global_id(1);
    const int split_q_idx = get_global_id(2);
    const int split_idx = split_q_idx % n_splits;
    const int q_idx = split_q_idx / n_splits;
    const int batch_idx = head_batch_idx / n_head;
    const int head_idx = head_batch_idx % n_head;
    const int gqa_ratio = n_head / n_head_kv;
    const int head_kv_idx = head_idx / gqa_ratio;

    // Half-open KV row range [kv_start, kv_end) owned by this split.
    const int kv_start = split_idx * kv_per_split;
    const int kv_end = min(kv_start + kv_per_split, n_kv);

    const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
    const ulong record_idx = ((((ulong) batch_idx * n_head + head_idx) * n_q + q_idx)
                              * n_splits + split_idx);
    global float * rec = partial_void + record_idx * record_stride;
    // The O payload starts two floats into the record, so it is only
    // guaranteed to be float-aligned. Keep it as a float pointer and store
    // with vstore4 instead of casting to float4* (which would assert
    // 16-byte alignment).
    global float * rec_o = rec + 2;

    if (kv_start >= kv_end) {
        // Empty split: leave sentinel partial for merge.
        // NOTE(review): this return precedes the barriers below; it is only
        // safe if every work-item of the work-group shares split_idx —
        // confirm the launch uses a local size of 1 in dimension 2.
        if (tid == 0) {
            rec[0] = FA_M_INIT;
            rec[1] = 0.0f;
        }
        return;
    }

    const global char * q_base = (const global char *) q_void + q_offset;
    const global char * k_base = (const global char *) k_void + k_offset;
    const global char * v_base = (const global char *) v_void + v_offset;

    // Mask row for this query; head and batch indices wrap with the mask extents.
    const global char * mask_base = NULL;
    if (mask_void != NULL) {
        const int mask_head_idx = head_idx % mask_ne2;
        const int mask_batch_idx = batch_idx % mask_ne3;
        mask_base = (const global char *) mask_void + mask_offset +
                    mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2 +
                    (ulong) q_idx * mask_nb1;
    }

    // Share Q via local memory (n_q=1 per split -> uniform across WG).
    __local ACC_TYPE4 q_shared[DK_VEC];
    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + (ulong) q_idx * q_nb1;
    const global Q_DATA_TYPE4 * q_ptr = (const global Q_DATA_TYPE4 *) (q_base + q_row_offset);
    for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
        q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    const float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);

    // Pass 1a — split-local max.
    ACC_TYPE m_i = FA_M_INIT;
    for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; ++k) {
            dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
            const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
            score += slope * (ACC_TYPE) mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        m_i = max(m_i, score);
    }

    // Tree reduction of the per-thread maxima; the result lands in local_m[0].
    __local ACC_TYPE local_m[Q1_WG_SIZE];
    local_m[tid] = m_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const ACC_TYPE m_c = local_m[0];

    // Pass 1b — softmax-weighted V accumulate.
    // Scores are recomputed rather than cached from pass 1a.
    ACC_TYPE4 o_acc[DV_VEC];
    #pragma unroll
    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
    ACC_TYPE l_i = 0.0f;

    for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
        const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
        const global KV_DATA_TYPE4 * v_ptr = (const global KV_DATA_TYPE4 *) (v_base + v_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; ++k) {
            dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
            const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
            score += slope * (ACC_TYPE) mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        const ACC_TYPE p = exp(score - m_c);
        l_i += p;
        #pragma unroll
        for (int i = 0; i < DV_VEC; ++i) {
            o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
        }
    }

    // Sum the per-thread denominators into local_l[0].
    __local ACC_TYPE local_l[Q1_WG_SIZE];
    __local ACC_TYPE4 local_o[Q1_WG_SIZE];
    local_l[tid] = l_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_l[tid] += local_l[tid + s];
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const ACC_TYPE l_c = local_l[0];

    if (tid == 0) {
        rec[0] = (float) m_c;
        rec[1] = (float) l_c;
    }
    // Reduce the output accumulator one float4 group at a time, reusing the
    // single local_o scratch array for every group.
    for (int i = 0; i < DV_VEC; ++i) {
        local_o[tid] = o_acc[i];
        barrier(CLK_LOCAL_MEM_FENCE);
        #pragma unroll
        for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
            if (tid < s) local_o[tid] += local_o[tid + s];
            barrier(CLK_LOCAL_MEM_FENCE);
        }
        if (tid == 0) {
            // Element-aligned vector store; writes floats [4*i, 4*i + 3] of O.
            vstore4(local_o[0], i, rec_o);
        }
    }
}
|
||||
|
||||
// FD Pass 2: merge per-split partials into final O.
// Each split record holds [m, l, O[DV]]. Splits whose recorded max is still
// FA_M_INIT (the sentinel written for empty splits) are skipped explicitly.
// The optional attention sink contributes exp(m_sink - m) to the denominator
// only; it adds nothing to the output vector.
//
// Work-item layout as read by this kernel:
//   get_local_id(0)  -> lane, one float4 group of the output row
//   get_global_id(1) -> batch_idx * n_head + head_idx
//   get_global_id(2) -> q_idx
__kernel void flash_attn_f32_merge(
    const global float * partial_void,
    global void * o_void,
    const ulong o_offset,
    const int n_head,
    const int n_splits,
    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
    const global void * sinks_void,
    const ulong sinks_offset,
    const int n_q
) {
    const int lane = get_local_id(0); // 0..DV_VEC-1
    const int head_batch_idx = get_global_id(1);
    const int q_idx = get_global_id(2);
    const int batch_idx = head_batch_idx / n_head;
    const int head_idx = head_batch_idx % n_head;

    // First record of this (batch, head, query); its n_splits records are contiguous.
    const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
    const ulong record_idx_0 = (((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) * n_splits;
    const global float * rec0 = partial_void + record_idx_0 * record_stride;

    // Lane 0 alone computes the merged max and denominator and publishes
    // them to the other lanes through local memory.
    __local ACC_TYPE m_final_shared;
    __local ACC_TYPE l_final_shared;
    if (lane == 0) {
        ACC_TYPE m = FA_M_INIT;
        for (int c = 0; c < n_splits; ++c) {
            const ACC_TYPE m_c = rec0[c * record_stride + 0];
            m = max(m, m_c);
        }
        ACC_TYPE m_sink = 0.0f;
        bool has_sink = false;
        if (sinks_void != NULL) {
            const global ACC_TYPE * sinks_ptr =
                (const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset);
            m_sink = sinks_ptr[head_idx];
            has_sink = true;
            m = max(m, m_sink);
        }
        ACC_TYPE l = 0.0f;
        for (int c = 0; c < n_splits; ++c) {
            const ACC_TYPE m_c = rec0[c * record_stride + 0];
            const ACC_TYPE l_c = rec0[c * record_stride + 1];
            // Skip splits that never raised their max above the sentinel.
            if (m_c > FA_M_INIT) {
                l += l_c * exp(m_c - m);
            }
        }
        if (has_sink) {
            l += exp(m_sink - m);
        }
        m_final_shared = m;
        l_final_shared = l;
    }
    barrier(CLK_LOCAL_MEM_FENCE);
    const ACC_TYPE m_final = m_final_shared;
    const ACC_TYPE l_final = l_final_shared;
    // A zero denominator yields an all-zero output row instead of a division by zero.
    const ACC_TYPE l_inv = (l_final > 0.0f) ? (1.0f / l_final) : 0.0f;

    ACC_TYPE4 o = (ACC_TYPE4)(0.0f);
    for (int c = 0; c < n_splits; ++c) {
        const global float * rec_c = rec0 + c * record_stride;
        const ACC_TYPE m_c = rec_c[0];
        if (m_c <= FA_M_INIT) continue;
        const ACC_TYPE scale_c = exp(m_c - m_final);
        // The O payload starts two floats into the record, so it is only
        // guaranteed to be float-aligned: read it with vload4 rather than
        // through a float4 pointer cast (which would assert 16-byte alignment).
        o = mad((ACC_TYPE4)(scale_c), vload4(lane, rec_c + 2), o);
    }
    o = o * l_inv;

    const ulong o_row_offset = (ulong) batch_idx * o_nb3 + (ulong) q_idx * o_nb2 + (ulong) head_idx * o_nb1;
    global O_DATA_TYPE4 * o_row = (global O_DATA_TYPE4 *) ((global char *) o_void + o_offset + o_row_offset);
    o_row[lane] = CONVERT_O_DATA4(o);
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,156 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
// Builds a BLOCK_N-row padded copy of the trailing (partial) KV tile.
//
// One work-item handles one row of one (batch, kv-head) pair:
//   global id 0 -> row inside the padded tile (0..BLOCK_N-1)
//   global id 1 -> kv head
//   global id 2 -> batch
// The tile starts at the last multiple of BLOCK_N that is <= n_kv. Rows that
// still fall inside [0, n_kv) are copied byte-for-byte (k_nb1 / v_nb1 bytes per
// row); the remaining rows are zero-filled. The padded buffers are written
// densely as [batch][head][BLOCK_N rows] with a row pitch of k_nb1 / v_nb1.
__kernel void flash_attn_kv_pad_f16(
        const global void * k_void, ulong k_offset,
        const global void * v_void, ulong v_offset,
        global void * k_pad_void,
        global void * v_pad_void,
        const int n_kv,
        const int n_head_kv,
        const int n_batch,
        const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
        const ulong v_nb1, const ulong v_nb2, const ulong v_nb3
) {
    const int row   = get_global_id(0);
    const int head  = get_global_id(1);
    const int batch = get_global_id(2);

    // Work-items outside the logical grid have nothing to do.
    const bool in_grid = row < BLOCK_N && head < n_head_kv && batch < n_batch;
    if (!in_grid) {
        return;
    }

    // First row of the partial tile, then the source row this work-item maps to.
    const int src_row = (n_kv - (n_kv % BLOCK_N)) + row;

    // Destination row inside the dense padded buffers.
    const ulong tile = (ulong) batch * (ulong) n_head_kv + (ulong) head;
    global char * k_dst = (global char *) k_pad_void + (tile * ((ulong) BLOCK_N * k_nb1) + (ulong) row * k_nb1);
    global char * v_dst = (global char *) v_pad_void + (tile * ((ulong) BLOCK_N * v_nb1) + (ulong) row * v_nb1);

    if (src_row >= n_kv) {
        // Past the end of the real KV data: emit an all-zero row.
        for (ulong b = 0; b < k_nb1; ++b) {
            k_dst[b] = 0;
        }
        for (ulong b = 0; b < v_nb1; ++b) {
            v_dst[b] = 0;
        }
        return;
    }

    // Valid source row: locate it through the strided K/V layouts and copy it.
    const global char * k_row = (const global char *) k_void + k_offset +
        ((ulong) batch * k_nb3 + (ulong) head * k_nb2 + (ulong) src_row * k_nb1);
    const global char * v_row = (const global char *) v_void + v_offset +
        ((ulong) batch * v_nb3 + (ulong) head * v_nb2 + (ulong) src_row * v_nb1);

    for (ulong b = 0; b < k_nb1; ++b) {
        k_dst[b] = k_row[b];
    }
    for (ulong b = 0; b < v_nb1; ++b) {
        v_dst[b] = v_row[b];
    }
}
|
||||
|
||||
// Builds a BLOCK_N-column padded copy of the trailing (partial) mask tile.
//
// One work-item produces one half value:
//   global id 0 -> column inside the padded tile (0..BLOCK_N-1)
//   global id 1 -> query row
//   global id 2 -> flattened mask slice (head + mask_ne2 * batch)
// Columns that map to a source column < n_kv are copied from the mask; all
// other columns are set to -INFINITY. Output is dense:
// [batch][head][n_q rows][BLOCK_N columns] of half.
__kernel void flash_attn_mask_pad_f16(
        const global void * mask_void, ulong mask_offset,
        global void * mask_pad_void,
        const int n_q,
        const int n_kv,
        const ulong mask_nb1,
        const ulong mask_nb2,
        const ulong mask_nb3,
        const int mask_ne2,
        const int mask_ne3
) {
    const int col   = get_global_id(0);
    const int row   = get_global_id(1);
    const int slice = get_global_id(2);

    if (col >= BLOCK_N) {
        return;
    }
    if (row >= n_q) {
        return;
    }
    if (slice >= mask_ne2 * mask_ne3) {
        return;
    }

    // Split the flattened slice index back into (head, batch).
    const int head  = slice % mask_ne2;
    const int batch = slice / mask_ne2;

    // Source column: start of the partial tile plus the column inside it.
    const int src_col = (n_kv - (n_kv % BLOCK_N)) + col;

    half value = (half) (-INFINITY);
    if (src_col < n_kv) {
        const global char * row_bytes = (const global char *) mask_void + mask_offset +
            (ulong) batch * mask_nb3 +
            (ulong) head  * mask_nb2 +
            (ulong) row   * mask_nb1;
        value = ((const global half *) row_bytes)[src_col];
    }

    const ulong dst_row = ((ulong) batch * (ulong) mask_ne2 + (ulong) head) * (ulong) n_q + (ulong) row;
    global half * out = (global half *) mask_pad_void;
    out[dst_row * (ulong) BLOCK_N + (ulong) col] = value;
}
|
||||
|
||||
// Per-KV-tile mask class. 0=all -inf (skip tile), 1=mixed (apply mask),
|
||||
// 2=all zero, no -inf (skip mask lookup). Causal diagonal tiles are class 1.
|
||||
// Classifies every (q-tile, kv-tile) pair of the mask and stores one byte per
// tile in blk:
//   0 -> no entry is above the "masked" threshold (every value <= -65504)
//   1 -> the tile mixes masked and unmasked entries, or holds an unmasked
//        entry that is not exactly zero
//   2 -> every entry is unmasked and equal to zero
//
// Work-item mapping: global id 0 = kv tile, global id 1 = q tile,
// global id 2 = flattened mask slice (head + mask_ne2 * batch).
// blk is written densely as [slice][q tile][kv tile].
__kernel void flash_attn_blk_f16(
        const global void * mask_void, ulong mask_offset,
        global char * blk,
        const int n_q,
        const int n_kv,
        const ulong mask_nb1,
        const ulong mask_nb2,
        const ulong mask_nb3,
        const int mask_ne2,
        const int mask_ne3
) {
    const int kv_block_idx = get_global_id(0);
    const int q_block_idx  = get_global_id(1);
    const int mask_slice   = get_global_id(2);

    const int n_q_blocks  = (n_q  + BLOCK_M - 1) / BLOCK_M;
    const int n_kv_blocks = (n_kv + BLOCK_N - 1) / BLOCK_N;
    if (kv_block_idx >= n_kv_blocks || q_block_idx >= n_q_blocks || mask_slice >= mask_ne2 * mask_ne3) {
        return;
    }

    const int mask_head_idx  = mask_slice % mask_ne2;
    const int mask_batch_idx = mask_slice / mask_ne2;
    const int q_start = q_block_idx  * BLOCK_M;
    const int k_start = kv_block_idx * BLOCK_N;
    // Edge tiles may be partial; only the valid rows/columns are inspected.
    const int q_count = min(BLOCK_M, n_q  - q_start);
    const int k_count = min(BLOCK_N, n_kv - k_start);

    // Anything at or below the most negative finite half counts as masked.
    const half neg_max_half = (half) (-65504.0f);
    char has_unmasked = 0;
    char has_masked   = 0;
    char has_nonzero  = 0;

    const global char * mask_base = (const global char *) mask_void + mask_offset +
        (ulong) mask_batch_idx * mask_nb3 +
        (ulong) mask_head_idx  * mask_nb2;

    for (int qi = 0; qi < q_count; ++qi) {
        const global half * mask_row = (const global half *) (mask_base + (ulong) (q_start + qi) * mask_nb1) + k_start;
        for (int ki = 0; ki < k_count; ++ki) {
            const half v = mask_row[ki];
            if (v <= neg_max_half) {
                has_masked = 1;
            } else {
                has_unmasked = 1;
                if (v != (half) 0.0f) {
                    has_nonzero = 1;
                }
            }
        }
        // The tile is already known to be class 1 once it is mixed, or once an
        // unmasked non-zero value was seen (has_nonzero implies has_unmasked).
        // Later rows cannot change that, so stop reading the mask.
        if (has_nonzero || (has_masked && has_unmasked)) break;
    }

    char res;
    if (has_unmasked == 0) {
        res = 0;
    } else if (has_masked || has_nonzero) {
        res = 1;
    } else {
        res = 2;
    }

    blk[((ulong) mask_slice * (ulong) n_q_blocks + (ulong) q_block_idx) * (ulong) n_kv_blocks + (ulong) kv_block_idx] = res;
}
|
||||
@@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32(
|
||||
regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
|
||||
regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
|
||||
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
|
||||
dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB);
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave=4
|
||||
|
||||
@@ -24,6 +24,7 @@ kernel void kernel_norm(
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne03,
|
||||
ulong nb00,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
@@ -43,7 +44,8 @@ kernel void kernel_norm(
|
||||
// parallel sum
|
||||
sum[get_local_id(0)] = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
sum[get_local_id(0)] += x[i00];
|
||||
// this kernel handles float, nb00/4 translates byte offset to element offset
|
||||
sum[get_local_id(0)] += x[i00*nb00/4];
|
||||
}
|
||||
// reduce
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -60,7 +62,8 @@ kernel void kernel_norm(
|
||||
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
sum[get_local_id(0)] = 0.0f;
|
||||
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
|
||||
y[i00] = x[i00] - mean;
|
||||
// this kernel handles float, nb00/4 translates byte offset to element offset
|
||||
y[i00] = x[i00*nb00/4] - mean;
|
||||
sum[get_local_id(0)] += y[i00] * y[i00];
|
||||
}
|
||||
|
||||
|
||||
@@ -158,6 +158,239 @@ kernel void kernel_set_rows_f32_i32(
|
||||
}
|
||||
}
|
||||
|
||||
// f32 -> q8_0 quantize set_rows. Block = half d + char qs[32].
|
||||
#define QK8_0 32
|
||||
|
||||
// Quantizes QK8_0 consecutive floats starting at x into one q8_0 block.
//   qs    : receives QK8_0 signed 8-bit quants, qs[k] = round(x[k] * 127 / amax)
//   d_out : receives the block scale d = amax / 127 as fp16
// where amax is the largest absolute value in the block. An all-zero block
// yields d = 0 and all-zero quants.
inline void quantize_q8_0_block(global float * x, global char * qs, global half * d_out) {
    // Pass 1: largest magnitude in the block.
    float abs_peak = 0.0f;
    for (int k = 0; k < QK8_0; ++k) {
        abs_peak = fmax(abs_peak, fabs(x[k]));
    }

    const float scale = abs_peak / 127.0f;

    // Inverse scale; left at zero when the scale is zero to avoid dividing by zero.
    float inv_scale = 0.0f;
    if (scale != 0.0f) {
        inv_scale = 127.0f / abs_peak;
    }

    vstore_half(scale, 0, d_out);

    // Pass 2: scale, round and narrow each value.
    for (int k = 0; k < QK8_0; ++k) {
        const int q = (int) round(x[k] * inv_scale);
        qs[k] = (char) q;
    }
}
|
||||
|
||||
// set_rows with f32 source rows, 64-bit row indices and an interleaved q8_0
// destination (each block = 2-byte fp16 scale followed by QK8_0 int8 quants).
//
// For source row i01 of plane (i02, i03) the destination row index is read
// from src1 and the row is quantized block by block into dst. Work-items along
// local dimension 0 stride over the nblk0 blocks of the row.
// NOTE(review): the row index combines get_group_id(0) with local dimension 1;
// presumably the host sizes the launch to match -- confirm against the caller.
kernel void kernel_set_rows_q8_0_i64(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    // Apply the sub-buffer byte offsets.
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (i01 >= ne01) {
        return;
    }

    // Index-tensor coordinates; fastmod is defined elsewhere in this file.
    const int i12 = fastmod(i03, ne12);
    const int i11 = fastmod(i02, ne11);
    const int i10 = i01;

    // Destination row chosen by the index tensor.
    const long target_row = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];

    global char  * out_row = (global char *) (dst + target_row*nb1 + i02*nb2 + i03*nb3);
    global float * in_row  = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);

    int blk = get_local_id(0);
    while (blk < nblk0) {
        global float * in_blk  = in_row + blk * QK8_0;
        global char  * out_blk = out_row + blk * (2 + QK8_0);

        // Scale lives in the first two bytes, quants follow.
        quantize_q8_0_block(in_blk, out_blk + 2, (global half *)out_blk);

        blk += get_local_size(0);
    }
}
|
||||
|
||||
// set_rows with f32 source rows, 32-bit row indices and an interleaved q8_0
// destination (each block = 2-byte fp16 scale followed by QK8_0 int8 quants).
//
// Identical to the i64 variant except that the destination row index is read
// from src1 as a 32-bit int. Work-items along local dimension 0 stride over
// the nblk0 blocks of the row.
// NOTE(review): the row index combines get_group_id(0) with local dimension 1;
// presumably the host sizes the launch to match -- confirm against the caller.
kernel void kernel_set_rows_q8_0_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    // Apply the sub-buffer byte offsets.
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (i01 >= ne01) {
        return;
    }

    // Index-tensor coordinates; fastmod is defined elsewhere in this file.
    const int i12 = fastmod(i03, ne12);
    const int i11 = fastmod(i02, ne11);
    const int i10 = i01;

    // Destination row chosen by the index tensor.
    const int target_row = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];

    global char  * out_row = (global char *) (dst + target_row*nb1 + i02*nb2 + i03*nb3);
    global float * in_row  = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);

    int blk = get_local_id(0);
    while (blk < nblk0) {
        global float * in_blk  = in_row + blk * QK8_0;
        global char  * out_blk = out_row + blk * (2 + QK8_0);

        // Scale lives in the first two bytes, quants follow.
        quantize_q8_0_block(in_blk, out_blk + 2, (global half *)out_blk);

        blk += get_local_size(0);
    }
}
|
||||
|
||||
// SoA q8_0 variants. dst_q: int8[QK8_0] per block; dst_d: fp16 scale per block.
|
||||
// Layout matches kernel_convert_block_q8_0; block index follows dst element order.
|
||||
// set_rows with f32 source rows, 64-bit row indices and a split (SoA) q8_0
// destination: quants go to dst_q (QK8_0 bytes per block) and fp16 scales go
// to dst_d (one half per block).
//
// The first block of a destination row is
//   ((i03 * ne2_dst + i02) * ne1_dst + row) * nblk0
// where row is the index read from src1. ne3_dst is accepted but not used.
// Work-items along local dimension 0 stride over the nblk0 blocks of the row.
kernel void kernel_set_rows_q8_0_soa_i64(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst_q,
        ulong         offset_q,
        global char * dst_d,
        ulong         offset_d,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        int           ne1_dst,
        int           ne2_dst,
        int           ne3_dst
) {
    // Apply the sub-buffer byte offsets.
    src0  += offset0;
    src1  += offset1;
    dst_q += offset_q;
    dst_d += offset_d;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (i01 >= ne01) {
        return;
    }

    // Index-tensor coordinates; fastmod is defined elsewhere in this file.
    const int i12 = fastmod(i03, ne12);
    const int i11 = fastmod(i02, ne11);
    const int i10 = i01;

    // Destination row chosen by the index tensor.
    const long target_row = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];

    // Linear index of the first block of the destination row.
    const long first_blk = (((long)i03 * ne2_dst + (long)i02) * ne1_dst + target_row) * nblk0;

    global half  * scales = (global half *)(dst_d) + first_blk;
    global char  * quants = (global char *)(dst_q) + first_blk * QK8_0;
    global float * in_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);

    int blk = get_local_id(0);
    while (blk < nblk0) {
        quantize_q8_0_block(in_row + blk * QK8_0, quants + blk * QK8_0, scales + blk);
        blk += get_local_size(0);
    }
}
|
||||
|
||||
// set_rows with f32 source rows, 32-bit row indices and a split (SoA) q8_0
// destination: quants go to dst_q (QK8_0 bytes per block) and fp16 scales go
// to dst_d (one half per block).
//
// The first block of a destination row is
//   ((i03 * ne2_dst + i02) * ne1_dst + row) * nblk0
// where row is the 32-bit index read from src1. ne3_dst is accepted but not
// used. Work-items along local dimension 0 stride over the nblk0 blocks.
kernel void kernel_set_rows_q8_0_soa_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst_q,
        ulong         offset_q,
        global char * dst_d,
        ulong         offset_d,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        int           ne1_dst,
        int           ne2_dst,
        int           ne3_dst
) {
    // Apply the sub-buffer byte offsets.
    src0  += offset0;
    src1  += offset1;
    dst_q += offset_q;
    dst_d += offset_d;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (i01 >= ne01) {
        return;
    }

    // Index-tensor coordinates; fastmod is defined elsewhere in this file.
    const int i12 = fastmod(i03, ne12);
    const int i11 = fastmod(i02, ne11);
    const int i10 = i01;

    // Destination row chosen by the index tensor.
    const int target_row = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];

    // Linear index of the first block of the destination row.
    const long first_blk = (((long)i03 * ne2_dst + (long)i02) * ne1_dst + target_row) * nblk0;

    global half  * scales = (global half *)(dst_d) + first_blk;
    global char  * quants = (global char *)(dst_q) + first_blk * QK8_0;
    global float * in_row = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);

    int blk = get_local_id(0);
    while (blk < nblk0) {
        quantize_q8_0_block(in_row + blk * QK8_0, quants + blk * QK8_0, scales + blk);
        blk += get_local_size(0);
    }
}
|
||||
|
||||
kernel void kernel_set_rows_f16_i32(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
@@ -206,3 +439,270 @@ kernel void kernel_set_rows_f16_i32(
|
||||
dst_row[ind] = src_row[ind];
|
||||
}
|
||||
}
|
||||
|
||||
// f32 -> q4_0 quantize set_rows. Block = half d + uchar qs[16] (shuffled
|
||||
// nibbles: qs[j] low/high = elem j / j+16).
|
||||
// Dequant: val[i] = d * (nibble_i - 8)
|
||||
// nblk0 = number of q4_0 blocks per row = ne00 / 32.
|
||||
#define QK4_0 32
|
||||
#define Q4_0_BLOCK_SIZE 18
|
||||
|
||||
// Quantizes QK4_0 consecutive floats starting at x into one q4_0 block.
//   d_out : receives the fp16 scale d = peak / -8, where peak is the signed
//           value with the largest magnitude in the block
//   qs    : receives QK4_0/2 bytes; byte k holds element k in its low nibble
//           and element k + QK4_0/2 in its high nibble
// Each nibble is (int)(x / d + 8.5) clamped to [0, 15]; an all-zero block
// yields d = 0 and every nibble equal to 8.
inline void quantize_q4_0_block(global float * x, global uchar * qs, global half * d_out) {
    // Track the signed value whose magnitude is largest (first one wins on ties).
    float peak     = 0.0f;
    float peak_abs = 0.0f;
    for (int k = 0; k < QK4_0; ++k) {
        const float v   = x[k];
        const float mag = fabs(v);
        if (mag > peak_abs) {
            peak_abs = mag;
            peak     = v;
        }
    }

    const float d = peak / -8.0f;

    // Inverse scale; left at zero when the scale is zero to avoid dividing by zero.
    float inv_d = 0.0f;
    if (d != 0.0f) {
        inv_d = 1.0f / d;
    }

    vstore_half(d, 0, d_out);

    const int half_blk = QK4_0/2;
    for (int k = 0; k < half_blk; ++k) {
        const float s_lo = x[k] * inv_d;
        const float s_hi = x[k + half_blk] * inv_d;

        // Offset by 8 and round to nearest via truncation of (value + 0.5).
        int lo = (int)(s_lo + 8.5f);
        int hi = (int)(s_hi + 8.5f);

        // Keep both values inside the 4-bit range.
        lo = (lo < 0) ? 0 : ((lo > 15) ? 15 : lo);
        hi = (hi < 0) ? 0 : ((hi > 15) ? 15 : hi);

        qs[k] = (uchar)lo | ((uchar)hi << 4);
    }
}
|
||||
|
||||
// set_rows with f32 source rows, 64-bit row indices and an interleaved q4_0
// destination (each block = Q4_0_BLOCK_SIZE bytes: 2-byte fp16 scale followed
// by packed nibbles).
//
// For source row i01 of plane (i02, i03) the destination row index is read
// from src1 and the row is quantized block by block into dst. Work-items along
// local dimension 0 stride over the nblk0 blocks of the row.
// NOTE(review): the row index combines get_group_id(0) with local dimension 1;
// presumably the host sizes the launch to match -- confirm against the caller.
kernel void kernel_set_rows_q4_0_i64(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    // Apply the sub-buffer byte offsets.
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (i01 >= ne01) {
        return;
    }

    // Index-tensor coordinates; fastmod is defined elsewhere in this file.
    const int i12 = fastmod(i03, ne12);
    const int i11 = fastmod(i02, ne11);
    const int i10 = i01;

    // Destination row chosen by the index tensor.
    const long target_row = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];

    global char  * out_row = (global char *) (dst + target_row*nb1 + i02*nb2 + i03*nb3);
    global float * in_row  = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);

    int blk = get_local_id(0);
    while (blk < nblk0) {
        global float * in_blk  = in_row + blk * QK4_0;
        global char  * out_blk = out_row + blk * Q4_0_BLOCK_SIZE;

        // Scale lives in the first two bytes, packed nibbles follow.
        quantize_q4_0_block(in_blk, (global uchar *)(out_blk + 2), (global half *)(out_blk));

        blk += get_local_size(0);
    }
}
|
||||
|
||||
// set_rows with f32 source rows, 32-bit row indices and an interleaved q4_0
// destination (each block = Q4_0_BLOCK_SIZE bytes: 2-byte fp16 scale followed
// by packed nibbles).
//
// Identical to the i64 variant except that the destination row index is read
// from src1 as a 32-bit int. Work-items along local dimension 0 stride over
// the nblk0 blocks of the row.
// NOTE(review): the row index combines get_group_id(0) with local dimension 1;
// presumably the host sizes the launch to match -- confirm against the caller.
kernel void kernel_set_rows_q4_0_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    // Apply the sub-buffer byte offsets.
    src0 += offset0;
    src1 += offset1;
    dst  += offsetd;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (i01 >= ne01) {
        return;
    }

    // Index-tensor coordinates; fastmod is defined elsewhere in this file.
    const int i12 = fastmod(i03, ne12);
    const int i11 = fastmod(i02, ne11);
    const int i10 = i01;

    // Destination row chosen by the index tensor.
    const int target_row = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];

    global char  * out_row = (global char *) (dst + target_row*nb1 + i02*nb2 + i03*nb3);
    global float * in_row  = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);

    int blk = get_local_id(0);
    while (blk < nblk0) {
        global float * in_blk  = in_row + blk * QK4_0;
        global char  * out_blk = out_row + blk * Q4_0_BLOCK_SIZE;

        // Scale lives in the first two bytes, packed nibbles follow.
        quantize_q4_0_block(in_blk, (global uchar *)(out_blk + 2), (global half *)(out_blk));

        blk += get_local_size(0);
    }
}
|
||||
|
||||
// SoA variants for q4_0 dst. Used when the backend has split block_q4_0 records
|
||||
// into separate quant (dst_q) and scale (dst_d) sub-buffers — same pattern as
|
||||
// the q8_0 SoA variants above.
|
||||
//
|
||||
// Layout (matches kernel_convert_block_q4_0, the "shuffled" variant):
|
||||
// dst_q: contiguous 16 packed nibbles per block, block i at offset i * 16 bytes.
|
||||
// dst_d: contiguous fp16 scales, block i at offset i * 2 bytes.
|
||||
// Nibble layout inside each byte is unchanged from AoS: qs[j] low nibble = element j,
|
||||
// qs[j] high nibble = element j+16. kernel_restore_block_q4_0 copies bytes as-is.
|
||||
// set_rows with f32 source rows, 64-bit row indices and a split (SoA) q4_0
// destination: packed nibbles go to dst_q (QK4_0/2 bytes per block) and fp16
// scales go to dst_d (one half per block).
//
// The first block of a destination row is
//   ((i03 * ne2_dst + i02) * ne1_dst + row) * nblk0
// where row is the index read from src1. ne3_dst is accepted but not used.
// Work-items along local dimension 0 stride over the nblk0 blocks of the row.
kernel void kernel_set_rows_q4_0_soa_i64(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst_q,
        ulong         offset_q,
        global char * dst_d,
        ulong         offset_d,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        int           ne1_dst,
        int           ne2_dst,
        int           ne3_dst
) {
    // Apply the sub-buffer byte offsets.
    src0  += offset0;
    src1  += offset1;
    dst_q += offset_q;
    dst_d += offset_d;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (i01 >= ne01) {
        return;
    }

    // Index-tensor coordinates; fastmod is defined elsewhere in this file.
    const int i12 = fastmod(i03, ne12);
    const int i11 = fastmod(i02, ne11);
    const int i10 = i01;

    // Destination row chosen by the index tensor.
    const long target_row = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];

    // Linear index of the first block of the destination row.
    const long first_blk = (((long)i03 * ne2_dst + (long)i02) * ne1_dst + target_row) * nblk0;

    global half  * scales  = (global half *)(dst_d) + first_blk;
    global uchar * nibbles = (global uchar *)(dst_q) + first_blk * (QK4_0/2);
    global float * in_row  = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);

    int blk = get_local_id(0);
    while (blk < nblk0) {
        quantize_q4_0_block(in_row + blk * QK4_0, nibbles + blk * (QK4_0/2), scales + blk);
        blk += get_local_size(0);
    }
}
|
||||
|
||||
// set_rows with f32 source rows, 32-bit row indices and a split (SoA) q4_0
// destination: packed nibbles go to dst_q (QK4_0/2 bytes per block) and fp16
// scales go to dst_d (one half per block).
//
// The first block of a destination row is
//   ((i03 * ne2_dst + i02) * ne1_dst + row) * nblk0
// where row is the 32-bit index read from src1. ne3_dst is accepted but not
// used. Work-items along local dimension 0 stride over the nblk0 blocks.
kernel void kernel_set_rows_q4_0_soa_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst_q,
        ulong         offset_q,
        global char * dst_d,
        ulong         offset_d,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        int           ne1_dst,
        int           ne2_dst,
        int           ne3_dst
) {
    // Apply the sub-buffer byte offsets.
    src0  += offset0;
    src1  += offset1;
    dst_q += offset_q;
    dst_d += offset_d;

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);

    if (i01 >= ne01) {
        return;
    }

    // Index-tensor coordinates; fastmod is defined elsewhere in this file.
    const int i12 = fastmod(i03, ne12);
    const int i11 = fastmod(i02, ne11);
    const int i10 = i01;

    // Destination row chosen by the index tensor.
    const int target_row = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];

    // Linear index of the first block of the destination row.
    const long first_blk = (((long)i03 * ne2_dst + (long)i02) * ne1_dst + target_row) * nblk0;

    global half  * scales  = (global half *)(dst_d) + first_blk;
    global uchar * nibbles = (global uchar *)(dst_q) + first_blk * (QK4_0/2);
    global float * in_row  = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);

    int blk = get_local_id(0);
    while (blk < nblk0) {
        quantize_q4_0_block(in_row + blk * QK4_0, nibbles + blk * (QK4_0/2), scales + blk);
        blk += get_local_size(0);
    }
}
|
||||
|
||||
@@ -1270,77 +1270,14 @@ void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecode
|
||||
}
|
||||
|
||||
std::string GgmlOvDecoder::compute_op_type(const ggml_tensor * node) {
|
||||
static const std::map<ggml_op, std::string> ops = {
|
||||
{GGML_OP_NONE, "GGML_OP_NONE" },
|
||||
{GGML_OP_ACC, "GGML_OP_ACC" },
|
||||
{GGML_OP_ADD, "GGML_OP_ADD" },
|
||||
{GGML_OP_ADD1, "GGML_OP_ADD1" },
|
||||
{GGML_OP_ADD_ID, "GGML_OP_ADD_ID" },
|
||||
{GGML_OP_CONCAT, "GGML_OP_CONCAT" },
|
||||
{GGML_OP_CONT, "GGML_OP_CONT" },
|
||||
{GGML_OP_DIV, "GGML_OP_DIV" },
|
||||
{GGML_OP_DUP, "GGML_OP_DUP" },
|
||||
{GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS" },
|
||||
{GGML_OP_MUL, "GGML_OP_MUL" },
|
||||
{GGML_OP_MUL_MAT, "GGML_OP_MUL_MAT" },
|
||||
{GGML_OP_MUL_MAT_ID, "GGML_OP_MUL_MAT_ID" },
|
||||
{GGML_OP_PERMUTE, "GGML_OP_PERMUTE" },
|
||||
{GGML_OP_RESHAPE, "GGML_OP_RESHAPE" },
|
||||
{GGML_OP_RMS_NORM, "GGML_OP_RMS_NORM" },
|
||||
{GGML_OP_NORM, "GGML_OP_NORM" },
|
||||
{GGML_OP_ROPE, "GGML_OP_ROPE" },
|
||||
{GGML_OP_SCALE, "GGML_OP_SCALE" },
|
||||
{GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX" },
|
||||
{GGML_OP_SUM_ROWS, "GGML_OP_SUM_ROWS" },
|
||||
{GGML_OP_SUB, "GGML_OP_SUB" },
|
||||
{GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE" },
|
||||
{GGML_OP_VIEW, "GGML_OP_VIEW" },
|
||||
{GGML_OP_SET_ROWS, "GGML_OP_SET_ROWS" },
|
||||
{GGML_OP_CPY, "GGML_OP_CPY" },
|
||||
{GGML_OP_FLASH_ATTN_EXT, "GGML_OP_FLASH_ATTN_EXT" },
|
||||
{GGML_OP_L2_NORM, "GGML_OP_L2_NORM" },
|
||||
{GGML_OP_CLAMP, "GGML_OP_CLAMP" },
|
||||
{GGML_OP_PAD, "GGML_OP_PAD" },
|
||||
{GGML_OP_SSM_CONV, "GGML_OP_SSM_CONV" },
|
||||
{GGML_OP_GATED_DELTA_NET, "GGML_OP_GATED_DELTA_NET"},
|
||||
{GGML_OP_ARGSORT, "GGML_OP_ARGSORT" },
|
||||
{GGML_OP_REPEAT, "GGML_OP_REPEAT" },
|
||||
{GGML_OP_IM2COL, "GGML_OP_IM2COL" }
|
||||
};
|
||||
static const std::map<ggml_unary_op, std::string> unary_ops = {
|
||||
{GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS" },
|
||||
{GGML_UNARY_OP_SGN, "GGML_UNARY_OP_SGN" },
|
||||
{GGML_UNARY_OP_NEG, "GGML_UNARY_OP_NEG" },
|
||||
{GGML_UNARY_OP_STEP, "GGML_UNARY_OP_STEP" },
|
||||
{GGML_UNARY_OP_TANH, "GGML_UNARY_OP_TANH" },
|
||||
{GGML_UNARY_OP_ELU, "GGML_UNARY_OP_ELU" },
|
||||
{GGML_UNARY_OP_RELU, "GGML_UNARY_OP_RELU" },
|
||||
{GGML_UNARY_OP_SIGMOID, "GGML_UNARY_OP_SIGMOID" },
|
||||
{GGML_UNARY_OP_GELU, "GGML_UNARY_OP_GELU" },
|
||||
{GGML_UNARY_OP_GELU_QUICK, "GGML_UNARY_OP_GELU_QUICK" },
|
||||
{GGML_UNARY_OP_SILU, "GGML_UNARY_OP_SILU" },
|
||||
{GGML_UNARY_OP_SOFTPLUS, "GGML_UNARY_OP_SOFTPLUS" },
|
||||
{GGML_UNARY_OP_HARDSWISH, "GGML_UNARY_OP_HARDSWISH" },
|
||||
{GGML_UNARY_OP_HARDSIGMOID, "GGML_UNARY_OP_HARDSIGMOID"},
|
||||
{GGML_UNARY_OP_EXP, "GGML_UNARY_OP_EXP" },
|
||||
{GGML_UNARY_OP_COUNT, "GGML_UNARY_OP_COUNT" }
|
||||
};
|
||||
static const std::map<ggml_glu_op, std::string> glu_ops = {
|
||||
{GGML_GLU_OP_SWIGLU, "GGML_GLU_OP_SWIGLU"},
|
||||
{GGML_GLU_OP_GEGLU, "GGML_GLU_OP_GEGLU" },
|
||||
{GGML_GLU_OP_REGLU, "GGML_GLU_OP_REGLU" }
|
||||
};
|
||||
|
||||
switch (node->op) {
|
||||
case GGML_OP_UNARY:
|
||||
return unary_ops.at(ggml_get_unary_op(node));
|
||||
return std::string("GGML_UNARY_OP_") + ggml_unary_op_name(ggml_get_unary_op(node));
|
||||
case GGML_OP_GLU:
|
||||
return glu_ops.at(ggml_get_glu_op(node));
|
||||
return std::string("GGML_GLU_OP_") + ggml_glu_op_name(ggml_get_glu_op(node));
|
||||
default:
|
||||
return ops.at(node->op);
|
||||
return std::string("GGML_OP_") + ggml_op_name(node->op);
|
||||
}
|
||||
static const std::string unknown_op = "UNKNOWN_GGML_OP";
|
||||
return unknown_op;
|
||||
}
|
||||
|
||||
const std::string & GgmlOvDecoder::get_op_type(int node_idx) const {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user