mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-07-03 20:53:08 +02:00
Compare commits
56 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d4cff114c0 | |||
| f113e02d5a | |||
| 152d337fad | |||
| 75a48a9055 | |||
| 067de93718 | |||
| b5315e16e0 | |||
| 94875285e4 | |||
| 5a460dea9f | |||
| c8ae9a750c | |||
| fdb1db877c | |||
| 4fc4ec5541 | |||
| a6647b1a32 | |||
| 13e673863b | |||
| b820cc8e6f | |||
| 6dbc1174b8 | |||
| 9d88e7cedd | |||
| 7af4279f45 | |||
| fd1a05791d | |||
| 0eca4d490e | |||
| 4f31eedb0c | |||
| 799fcc04a5 | |||
| 931eb37f8c | |||
| e495d1e748 | |||
| f708a5b2ca | |||
| d9df11006f | |||
| 6c5de1cc83 | |||
| 86b94708f2 | |||
| 6f4f53f2b7 | |||
| 25a1d63f43 | |||
| 8c146a8366 | |||
| 6cb18b2f2e | |||
| 277a105dc8 | |||
| b3fed31b99 | |||
| dbdaece23d | |||
| 7cb8576e7c | |||
| fa72bc6826 | |||
| c818263f2a | |||
| f68a788b0b | |||
| d1b34251bc | |||
| c1a1c8ee94 | |||
| 27c8bb4f63 | |||
| ebd048fc5e | |||
| 0ed235ea2c | |||
| 9bebfcb4bc | |||
| 0b6529d818 | |||
| c299a92c38 | |||
| 0275c0f800 | |||
| 83d385b429 | |||
| 050ee92d04 | |||
| 3fc4e10527 | |||
| 5d8ccdf9d1 | |||
| 024930c6ad | |||
| 5397c36194 | |||
| e7ea94afcb | |||
| 96183e9820 | |||
| 487a6cc164 |
@@ -145,7 +145,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
# ==============================================================================
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
ENTRYPOINT [ "/app/llama-cli" ]
|
||||
|
||||
@@ -156,7 +156,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
HEALTHCHECK --interval=5m CMD [ "curl", "-f", "http://localhost:8080/health" ]
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -115,7 +115,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -124,7 +124,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/lib/ /app
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -153,7 +153,7 @@ FROM base AS server
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/lib/ /app
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -115,7 +115,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -126,7 +126,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
ARG OPENVINO_VERSION_MAJOR=2026.2
|
||||
ARG OPENVINO_VERSION_FULL=2026.2.0.21903.52ddc073857
|
||||
ARG OPENVINO_VERSION_MAJOR=2026.2.1
|
||||
ARG OPENVINO_VERSION_FULL=2026.2.1.21919.ede283a88e3
|
||||
ARG UBUNTU_VERSION=24.04
|
||||
|
||||
# Intel GPU driver versions. https://github.com/intel/compute-runtime/releases
|
||||
ARG IGC_VERSION=v2.34.4
|
||||
ARG IGC_VERSION_FULL=2_2.34.4+21428
|
||||
ARG COMPUTE_RUNTIME_VERSION=26.18.38308.1
|
||||
ARG COMPUTE_RUNTIME_VERSION_FULL=26.18.38308.1-0
|
||||
ARG IGC_VERSION=v2.36.3
|
||||
ARG IGC_VERSION_FULL=2_2.36.3+21719
|
||||
ARG COMPUTE_RUNTIME_VERSION=26.22.38646.4
|
||||
ARG COMPUTE_RUNTIME_VERSION_FULL=26.22.38646.4-0
|
||||
ARG IGDGMM_VERSION=22.10.0
|
||||
|
||||
# Intel NPU driver versions. https://github.com/intel/linux-npu-driver/releases
|
||||
@@ -214,7 +214,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app/
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -225,7 +225,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app/
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -127,7 +127,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -138,7 +138,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ WORKDIR /llama.cpp/bin
|
||||
|
||||
# Copy llama.cpp binaries and libraries
|
||||
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama-cli /llama.cpp/bin/llama-completion /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama /llama.cpp/bin/llama-cli /llama.cpp/bin/llama-completion /llama.cpp/bin
|
||||
|
||||
ENTRYPOINT [ "/llama.cpp/bin/llama-cli" ]
|
||||
|
||||
@@ -138,7 +138,7 @@ WORKDIR /llama.cpp/bin
|
||||
|
||||
# Copy llama.cpp binaries and libraries
|
||||
COPY --from=collector /llama.cpp/bin/*.so /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama-server /llama.cpp/bin
|
||||
COPY --from=collector /llama.cpp/bin/llama /llama.cpp/bin/llama-server /llama.cpp/bin
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -118,7 +118,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ ENTRYPOINT ["/app/tools.sh"]
|
||||
### Light, CLI only
|
||||
FROM base AS light
|
||||
|
||||
COPY --from=build /app/full/llama-cli /app/full/llama-completion /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-cli /app/full/llama-completion /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -108,7 +108,7 @@ FROM base AS server
|
||||
|
||||
ENV LLAMA_ARG_HOST=0.0.0.0
|
||||
|
||||
COPY --from=build /app/full/llama-server /app
|
||||
COPY --from=build /app/full/llama /app/full/llama-server /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -68,8 +68,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -96,8 +96,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
@@ -39,8 +39,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
@@ -96,8 +96,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
@@ -266,8 +266,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
|
||||
@@ -446,8 +446,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Set OpenVINO version output
|
||||
@@ -506,8 +506,11 @@ jobs:
|
||||
cmake -B build/ReleaseOV -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGGML_OPENVINO=ON \
|
||||
-DHF_UI_VERSION=${{ needs.get-version.outputs.ui_version }}
|
||||
cmake --build build/ReleaseOV --config Release -j $(nproc)
|
||||
-DCMAKE_INSTALL_RPATH='$ORIGIN' \
|
||||
-DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \
|
||||
-DHF_UI_VERSION=${{ needs.get-version.outputs.ui_version }} \
|
||||
${{ env.CMAKE_ARGS }}
|
||||
cmake --build build/ReleaseOV --config Release --parallel
|
||||
|
||||
- name: ccache-clear
|
||||
uses: ./.github/actions/ccache-clear
|
||||
@@ -521,8 +524,26 @@ jobs:
|
||||
- name: Pack artifacts
|
||||
id: pack_artifacts
|
||||
run: |
|
||||
cp LICENSE ./build/ReleaseOV/bin/
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz --transform "s,^\.,llama-${{ steps.tag.outputs.name }}," -C ./build/ReleaseOV/bin .
|
||||
dest=./build/ReleaseOV/bin
|
||||
OPENVINO_ROOT=./openvino_toolkit
|
||||
ov_lib="$OPENVINO_ROOT/runtime/lib/intel64"
|
||||
|
||||
# Bundle OpenVINO runtime libs + TBB. Binaries built with RPATH=$ORIGIN
|
||||
# load these siblings without setupvars.sh / LD_LIBRARY_PATH.
|
||||
cp -P "$ov_lib"/libopenvino.so* \
|
||||
"$ov_lib"/libopenvino_c.so* \
|
||||
"$ov_lib"/libopenvino_*_plugin.so \
|
||||
"$ov_lib"/libopenvino_intel_npu_compiler*.so \
|
||||
"$OPENVINO_ROOT"/runtime/3rdparty/tbb/lib/*.so* \
|
||||
"$dest"
|
||||
cp -P /usr/lib/x86_64-linux-gnu/libOpenCL.so.1* "$dest" 2>/dev/null || true
|
||||
cp "$ov_lib"/cache.json "$dest" 2>/dev/null || true
|
||||
|
||||
# OpenVINO licensing
|
||||
cp -r "$OPENVINO_ROOT"/docs/licensing "$dest"/openvino-licensing
|
||||
|
||||
cp LICENSE "$dest"
|
||||
tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz --transform "s,^\.,llama-${{ steps.tag.outputs.name }}," -C "$dest" .
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
@@ -531,6 +552,9 @@ jobs:
|
||||
name: llama-bin-ubuntu-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.tar.gz
|
||||
|
||||
windows-openvino:
|
||||
needs: [check-release]
|
||||
if: ${{ needs.check-release.outputs.should_release == 'true' }}
|
||||
|
||||
runs-on: windows-2022
|
||||
|
||||
outputs:
|
||||
@@ -538,8 +562,8 @@ jobs:
|
||||
|
||||
env:
|
||||
# Sync versions in build-openvino.yml, build-self-hosted.yml, release.yml, build-cache.yml, .devops/openvino.Dockerfile
|
||||
OPENVINO_VERSION_MAJOR: "2026.2"
|
||||
OPENVINO_VERSION_FULL: "2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR: "2026.2.1"
|
||||
OPENVINO_VERSION_FULL: "2026.2.1.21919.ede283a88e3"
|
||||
|
||||
steps:
|
||||
- name: Set OpenVINO version output
|
||||
@@ -607,7 +631,9 @@ jobs:
|
||||
-A x64 ^
|
||||
-DCMAKE_BUILD_TYPE=Release ^
|
||||
-DGGML_OPENVINO=ON ^
|
||||
-DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake
|
||||
-DLLAMA_BUILD_BORINGSSL=ON ^
|
||||
-DCMAKE_TOOLCHAIN_FILE=C:\vcpkg\scripts\buildsystems\vcpkg.cmake ^
|
||||
${{ env.CMAKE_ARGS }}
|
||||
|
||||
cmake --build build\ReleaseOV --config Release -- /m
|
||||
|
||||
@@ -624,8 +650,29 @@ jobs:
|
||||
id: pack_artifacts
|
||||
shell: powershell
|
||||
run: |
|
||||
Copy-Item LICENSE .\build\ReleaseOV\bin\
|
||||
7z a -snl llama-${{ steps.tag.outputs.name }}-bin-win-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.zip .\build\ReleaseOV\bin\*
|
||||
# Locate the extracted OpenVINO toolkit root (same pattern as the Build step).
|
||||
$OPENVINO_ROOT = (Get-ChildItem -Directory openvino_toolkit | Select-Object -First 1).FullName
|
||||
if (-not $OPENVINO_ROOT) {
|
||||
Write-Error "OpenVINO toolkit folder not found under .\openvino_toolkit"
|
||||
exit 1
|
||||
}
|
||||
|
||||
$dest = ".\build\ReleaseOV\bin\Release"
|
||||
|
||||
$ovBin = Join-Path $OPENVINO_ROOT 'runtime\bin\intel64\Release'
|
||||
Copy-Item -Path (Join-Path $ovBin '*.dll') -Destination $dest -Force
|
||||
Copy-Item -Path (Join-Path $ovBin 'cache.json') -Destination $dest -Force
|
||||
|
||||
$tbbBin = Join-Path $OPENVINO_ROOT 'runtime\3rdparty\tbb\bin'
|
||||
Copy-Item -Path (Join-Path $tbbBin 'tbb*.dll') -Destination $dest -Force
|
||||
|
||||
# OpenVINO licensing
|
||||
$licensingDest = Join-Path $dest 'openvino-licensing'
|
||||
New-Item -ItemType Directory -Force -Path $licensingDest | Out-Null
|
||||
Copy-Item -Path (Join-Path $OPENVINO_ROOT 'docs\licensing\*') -Destination $licensingDest -Recurse -Force
|
||||
|
||||
Copy-Item LICENSE $dest
|
||||
7z a -snl llama-${{ steps.tag.outputs.name }}-bin-win-openvino-${{ env.OPENVINO_VERSION_MAJOR }}-x64.zip $dest\*
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
|
||||
+1
-1
@@ -80,7 +80,7 @@ To protect sensitive data from potential leaks or unauthorized access, it is cru
|
||||
### Untrusted environments or networks
|
||||
|
||||
If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions:
|
||||
* Do not use the RPC backend, [rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
|
||||
* Do not use the RPC backend, [ggml-rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
|
||||
* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value.
|
||||
* Encrypt your data if sending it over the network.
|
||||
|
||||
|
||||
+8
-4
@@ -50,6 +50,7 @@ struct command {
|
||||
std::vector<std::string> aliases;
|
||||
bool hidden;
|
||||
int (*func)(int, char **);
|
||||
bool flags = false; // allow --name
|
||||
};
|
||||
|
||||
#ifdef LLAMA_INSTALL_BUILD
|
||||
@@ -69,9 +70,9 @@ static const command cmds[] = {
|
||||
{"fit-params", "Compute parameters to fit a model in device memory", {}, true, llama_fit_params },
|
||||
{"quantize", "Quantize a model", {}, true, llama_quantize },
|
||||
{"perplexity", "Compute model perplexity and KL divergence", {}, true, llama_perplexity },
|
||||
{"version", "Show version", {}, false, version },
|
||||
{"licenses", "Show third-party licenses", {"credits"}, false, licenses },
|
||||
{"help", "Show available commands", {}, false, help },
|
||||
{"version", "Show version", {}, false, version, true },
|
||||
{"licenses", "Show third-party licenses", {"credits"}, false, licenses, true },
|
||||
{"help", "Show available commands", {}, false, help, true },
|
||||
};
|
||||
|
||||
#undef UPDATE_HIDDEN
|
||||
@@ -108,7 +109,10 @@ static int help(int argc, char ** argv) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static bool matches(const std::string & arg, const command & cmd) {
|
||||
static bool matches(std::string arg, const command & cmd) {
|
||||
if (cmd.flags && arg.size() > 2 && arg[0] == '-' && arg[1] == '-') {
|
||||
arg.erase(0, 2);
|
||||
}
|
||||
if (arg == cmd.name) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -94,10 +94,8 @@ add_library(${TARGET}
|
||||
peg-parser.h
|
||||
preset.cpp
|
||||
preset.h
|
||||
regex-partial.cpp
|
||||
reasoning-budget.cpp
|
||||
reasoning-budget.h
|
||||
regex-partial.h
|
||||
sampling.cpp
|
||||
sampling.h
|
||||
speculative.cpp
|
||||
|
||||
+67
-13
@@ -352,6 +352,8 @@ static std::string get_default_local_path(const std::string & url) {
|
||||
|
||||
common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex) {
|
||||
common_download_hf_plan plan;
|
||||
common_download_hf_plan plan_spec;
|
||||
common_download_hf_plan plan_voc;
|
||||
common_download_opts opts;
|
||||
|
||||
const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(),
|
||||
@@ -377,7 +379,15 @@ common_models_handler common_models_handler_init(const common_params & params, l
|
||||
plan = common_download_get_hf_plan(params.model, opts);
|
||||
}
|
||||
|
||||
return common_models_handler{plan, opts};
|
||||
if (!params.speculative.draft.mparams.hf_repo.empty()) {
|
||||
plan_spec = common_download_get_hf_plan(params.speculative.draft.mparams, opts);
|
||||
}
|
||||
|
||||
if (!params.vocoder.model.hf_repo.empty()) {
|
||||
plan_voc = common_download_get_hf_plan(params.vocoder.model, opts);
|
||||
}
|
||||
|
||||
return common_models_handler{plan, plan_spec, plan_voc, opts};
|
||||
}
|
||||
|
||||
bool common_models_handler_is_preset_repo(const common_models_handler & handler) {
|
||||
@@ -425,7 +435,9 @@ static std::vector<common_download_task> build_url_tasks(const common_params_mod
|
||||
void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback) {
|
||||
std::vector<common_download_task> tasks;
|
||||
|
||||
auto & plan = handler.plan;
|
||||
auto & plan = handler.plan;
|
||||
auto & plan_spec = handler.plan_spec;
|
||||
auto & plan_voc = handler.plan_voc;
|
||||
|
||||
auto opts = handler.opts; // copy
|
||||
opts.callback = callback;
|
||||
@@ -455,7 +467,7 @@ void common_models_handler_apply(common_models_handler & handler, common_params
|
||||
// the first part is what gets loaded, so point params.model.path at it
|
||||
if (!url_tasks.empty()) {
|
||||
std::string first_path = url_tasks.front().local_path;
|
||||
url_tasks.front().on_done = [&]() { params.model.path = first_path; };
|
||||
url_tasks.front().on_done = [&, first_path]() { params.model.path = first_path; };
|
||||
}
|
||||
for (auto & task : url_tasks) {
|
||||
tasks.push_back(std::move(task));
|
||||
@@ -484,19 +496,24 @@ void common_models_handler_apply(common_models_handler & handler, common_params
|
||||
}
|
||||
|
||||
// handle hf_plan tasks
|
||||
if (!plan.model_files.empty()) {
|
||||
for (size_t i = 0; i < plan.model_files.size(); ++i) {
|
||||
auto & model_file = plan.model_files[i];
|
||||
bool is_first = (i == 0);
|
||||
tasks.emplace_back(model_file, opts, [&, is_first]() {
|
||||
if (is_first) {
|
||||
// only use first part as model path
|
||||
params.model.path = hf_cache::finalize_file(model_file);
|
||||
auto add_tasks = [&opts, &tasks](const hf_cache::hf_files & model_files,
|
||||
const hf_cache::hf_file & primary,
|
||||
common_params_model & model) {
|
||||
for (size_t i = 0; i < model_files.size(); ++i) {
|
||||
auto & model_file = model_files[i];
|
||||
bool is_primary = (model_file.path == primary.path);
|
||||
tasks.emplace_back(model_file, opts, [&, is_primary]() {
|
||||
if (is_primary) {
|
||||
// the primary file is the first split (00001-of), use it as model path
|
||||
model.path = hf_cache::finalize_file(model_file);
|
||||
} else {
|
||||
hf_cache::finalize_file(model_file);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
if (!plan.model_files.empty()) {
|
||||
add_tasks(plan.model_files, plan.primary, params.model);
|
||||
}
|
||||
if (!plan.mmproj.local_path.empty()) {
|
||||
tasks.emplace_back(plan.mmproj, opts, [&]() {
|
||||
@@ -522,9 +539,31 @@ void common_models_handler_apply(common_models_handler & handler, common_params
|
||||
});
|
||||
}
|
||||
|
||||
// handle plan_spec (e.g. --spec-draft-hf)
|
||||
if (!plan_spec.model_files.empty()) {
|
||||
add_tasks(plan_spec.model_files, plan_spec.primary, params.speculative.draft.mparams);
|
||||
}
|
||||
|
||||
// handle vocoder plan (e.g. --hf-repo-v)
|
||||
if (!plan_voc.model_files.empty()) {
|
||||
add_tasks(plan_voc.model_files, plan_voc.primary, params.vocoder.model);
|
||||
}
|
||||
|
||||
// run all tasks in parallel
|
||||
if (!params.offline) {
|
||||
common_download_run_tasks(tasks);
|
||||
// if duplicated files are found, only download once (but still call on_done for each task)
|
||||
std::unordered_map<std::string, common_download_task *> unique_tasks;
|
||||
for (auto & task : tasks) {
|
||||
auto it = unique_tasks.find(task.local_path);
|
||||
if (it == unique_tasks.end()) {
|
||||
unique_tasks[task.local_path] = &task;
|
||||
}
|
||||
}
|
||||
std::vector<common_download_task> unique_tasks_vec;
|
||||
for (auto & pair : unique_tasks) {
|
||||
unique_tasks_vec.push_back(*pair.second);
|
||||
}
|
||||
common_download_run_tasks(unique_tasks_vec);
|
||||
}
|
||||
|
||||
// download successful, update params with the downloaded paths
|
||||
@@ -3259,6 +3298,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.sampling.reasoning_budget_message = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-preserve"},
|
||||
{"--no-reasoning-preserve"},
|
||||
"preserve reasoning trace in the full history, not just the last assistant message (default: template default)\n"
|
||||
"compatible with certain templates having 'supports_preserve_reasoning' capability\n"
|
||||
"example: https://docs.z.ai/guides/capabilities/thinking-mode#preserved-thinking",
|
||||
[](common_params & params, bool value) {
|
||||
if (value) {
|
||||
params.default_template_kwargs["preserve_reasoning"] = "true";
|
||||
} else {
|
||||
params.default_template_kwargs["preserve_reasoning"] = "false";
|
||||
}
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_REASONING_PRESERVE"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template"}, "JINJA_TEMPLATE",
|
||||
string_format(
|
||||
@@ -3434,7 +3487,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
[](common_params & params) {
|
||||
params.offline = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_OFFLINE"));
|
||||
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_OFFLINE"));
|
||||
add_opt(common_arg(
|
||||
{"-lv", "--verbosity", "--log-verbosity"}, "N",
|
||||
string_format("Set the verbosity threshold. Messages with a higher verbosity will be ignored. Values:\n"
|
||||
@@ -3711,6 +3764,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
"draft model for speculative decoding (default: unused)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.speculative.draft.mparams.path = value;
|
||||
params.speculative.draft.mparams.hf_file = value; // will be used if --spec-draft-hf is set
|
||||
}
|
||||
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
|
||||
add_opt(common_arg(
|
||||
|
||||
@@ -133,6 +133,8 @@ void common_params_add_preset_options(std::vector<common_arg> & args);
|
||||
|
||||
struct common_models_handler {
|
||||
common_download_hf_plan plan;
|
||||
common_download_hf_plan plan_spec;
|
||||
common_download_hf_plan plan_voc;
|
||||
common_download_opts opts;
|
||||
};
|
||||
|
||||
|
||||
+155
@@ -912,6 +912,10 @@ static std::string common_chat_template_direct_apply_impl(
|
||||
if (inputs.add_generation_prompt) {
|
||||
inp["add_generation_prompt"] = true;
|
||||
}
|
||||
if (inp.contains("preserve_reasoning") && inp["preserve_reasoning"].is_boolean()) {
|
||||
bool enabled = inp["preserve_reasoning"].get<bool>();
|
||||
jinja::caps_apply_preserve_reasoning(ctx, enabled);
|
||||
}
|
||||
|
||||
jinja::global_from_json(ctx, inp, inputs.mark_input);
|
||||
|
||||
@@ -2376,6 +2380,149 @@ static void func_args_not_string(json & messages) {
|
||||
|
||||
}
|
||||
|
||||
// MiniCPM5 format:
|
||||
// - Reasoning: <think>{reasoning}</think> (optional)
|
||||
// - Tool calls: <function name="foo"><param name="bar">value</param></function>
|
||||
static common_chat_params common_chat_params_init_minicpm5(const common_chat_template & tmpl,
|
||||
const autoparser::generation_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
data.prompt = common_chat_template_direct_apply_impl(tmpl, inputs);
|
||||
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs);
|
||||
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
|
||||
data.supports_thinking = true;
|
||||
data.preserved_tokens = {
|
||||
"<function",
|
||||
"<param",
|
||||
"</function>",
|
||||
"</param>",
|
||||
"<think>",
|
||||
"</think>",
|
||||
};
|
||||
|
||||
data.thinking_start_tag = "<think>";
|
||||
data.thinking_end_tag = "</think>";
|
||||
|
||||
data.message_delimiters = {
|
||||
{ COMMON_CHAT_ROLE_ASSISTANT, "<|im_start|>assistant" },
|
||||
{ COMMON_CHAT_ROLE_TOOL, "<|im_start|>user\n<tool_response>" },
|
||||
{ COMMON_CHAT_ROLE_USER, "<|im_start|>user" },
|
||||
{ COMMON_CHAT_ROLE_SYSTEM, "<|im_start|>system" },
|
||||
};
|
||||
|
||||
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
|
||||
auto has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty();
|
||||
auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE;
|
||||
auto include_grammar = has_response_format || (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE);
|
||||
|
||||
if (inputs.has_continuation()) {
|
||||
const auto & msg = inputs.continue_msg;
|
||||
|
||||
data.generation_prompt = "<|im_start|>assistant\n<think>\n" + msg.reasoning_content;
|
||||
if (inputs.continue_final_message == COMMON_CHAT_CONTINUATION_CONTENT) {
|
||||
data.generation_prompt += "\n</think>\n\n" + msg.render_content();
|
||||
}
|
||||
|
||||
data.prompt += data.generation_prompt;
|
||||
}
|
||||
|
||||
auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) {
|
||||
auto generation_prompt = p.literal("<|im_start|>assistant\n");
|
||||
|
||||
auto reasoning = p.eps();
|
||||
if (extract_reasoning) {
|
||||
reasoning = ("<think>" << p.reasoning(p.until("</think>")) << "</think>") + p.space();
|
||||
}
|
||||
|
||||
// Response format parser
|
||||
if (has_response_format) {
|
||||
return generation_prompt + reasoning + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
|
||||
}
|
||||
|
||||
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
// CDATA lets a value carry characters that would otherwise close the tag (e.g.
|
||||
// </param>); capture the inner text only, excluding the CDATA markers.
|
||||
auto string_value = p.choice({
|
||||
p.literal("<![CDATA[") + p.ac(p.tool_arg_string_value(p.until("]]>")) + p.literal("]]>"), "]]>") + p.tool_arg_close(p.literal("</param>")),
|
||||
p.negate(p.literal("< {
|
||||
const auto & function = tool.at("function");
|
||||
const std::string name = function.at("name");
|
||||
auto params = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
|
||||
auto args = p.eps();
|
||||
if (params.contains("properties") && params.at("properties").is_object() && !params.at("properties").empty()) {
|
||||
auto schema_info = common_schema_info();
|
||||
schema_info.resolve_refs(params);
|
||||
|
||||
auto arg_choice = p.choice();
|
||||
for (const auto & [prop_name, prop_schema] : params.at("properties").items()) {
|
||||
auto value_parser = p.eps();
|
||||
if (schema_info.resolves_to_string(prop_schema)) {
|
||||
value_parser = string_value;
|
||||
} else {
|
||||
value_parser = p.tool_arg_json_value(
|
||||
p.schema(p.json(), "tool-" + name + "-arg-" + prop_name + "-schema", prop_schema, false)
|
||||
) + p.tool_arg_close(p.literal("</param>"));
|
||||
}
|
||||
|
||||
auto arg_rule = p.tool_arg(
|
||||
p.tool_arg_open(p.literal("<param name=\"") + p.tool_arg_name(p.literal(prop_name)) + p.literal("\">")) +
|
||||
value_parser
|
||||
);
|
||||
|
||||
arg_choice |= arg_rule;
|
||||
}
|
||||
args = p.zero_or_more(arg_choice + p.space());
|
||||
}
|
||||
|
||||
auto tool_parser = p.tool(
|
||||
p.tool_open(p.literal("<function name=\"") + p.tool_name(p.literal(name)) + p.literal("\">"))
|
||||
<< p.tool_args(args)
|
||||
<< p.tool_close(p.literal("</function>")));
|
||||
|
||||
tool_choice |= p.rule("tool-" + name, tool_parser);
|
||||
});
|
||||
|
||||
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
|
||||
auto tool_calls = p.trigger_rule("tool-call", p.repeat(tool_choice + p.space(), 1, max_calls));
|
||||
|
||||
auto content = p.content(p.until("<function"));
|
||||
|
||||
return generation_prompt + reasoning + content + tool_calls + p.end();
|
||||
}
|
||||
|
||||
return generation_prompt + reasoning + p.content(p.rest()) + p.end();
|
||||
});
|
||||
|
||||
data.parser = parser.save();
|
||||
|
||||
if (include_grammar) {
|
||||
data.grammar_lazy = !(has_response_format || (has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED));
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
auto schema = function.contains("parameters") ? function.at("parameters") : json::object();
|
||||
builder.resolve_refs(schema);
|
||||
});
|
||||
if (has_response_format) {
|
||||
auto schema = inputs.json_schema;
|
||||
builder.resolve_refs(schema);
|
||||
}
|
||||
parser.build_grammar(builder, data.grammar_lazy);
|
||||
});
|
||||
|
||||
data.grammar_triggers = {
|
||||
{ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function" },
|
||||
};
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static json common_chat_extra_context() {
|
||||
json ctx = json::object();
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
@@ -2468,6 +2615,14 @@ std::optional<common_chat_params> common_chat_try_specialized_template(
|
||||
return common_chat_params_init_gemma4(tmpl, params);
|
||||
}
|
||||
|
||||
// MiniCPM5 - XML tool calls with <function name="..."><param name="...">...</param></function>
|
||||
if (src.find("Tool usage guidelines:") != std::string::npos &&
|
||||
src.find("<function name=\"") != std::string::npos &&
|
||||
src.find("<param name=\"") != std::string::npos) {
|
||||
LOG_DBG("Using specialized template: MiniCPM5\n");
|
||||
return common_chat_params_init_minicpm5(tmpl, params);
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
|
||||
+47
-47
@@ -225,7 +225,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
}
|
||||
|
||||
if (!SetPriorityClass(GetCurrentProcess(), p)) {
|
||||
LOG_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
|
||||
COM_WRN("failed to set process priority class %d : (%d)\n", prio, (int) GetLastError());
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -251,7 +251,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {
|
||||
}
|
||||
|
||||
if (setpriority(PRIO_PROCESS, 0, p) != 0) {
|
||||
LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
|
||||
COM_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
@@ -284,14 +284,14 @@ void postprocess_cpu_params(common_cpu_params & cpuparams, const common_cpu_para
|
||||
|
||||
if (n_set && n_set < cpuparams.n_threads) {
|
||||
// Not enough set bits, may experience performance issues.
|
||||
LOG_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
|
||||
COM_WRN("Not enough set bits in CPU mask (%d) to satisfy requested thread count: %d\n", n_set, cpuparams.n_threads);
|
||||
}
|
||||
}
|
||||
|
||||
bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THREADS]) {
|
||||
size_t dash_loc = range.find('-');
|
||||
if (dash_loc == std::string::npos) {
|
||||
LOG_ERR("Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
|
||||
COM_ERR("%s", "Format of CPU range is invalid! Expected [<start>]-[<end>].\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -303,7 +303,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
|
||||
} else {
|
||||
start_i = std::stoull(range.substr(0, dash_loc));
|
||||
if (start_i >= GGML_MAX_N_THREADS) {
|
||||
LOG_ERR("Start index out of bounds!\n");
|
||||
COM_ERR("%s", "Start index out of bounds!\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -313,7 +313,7 @@ bool parse_cpu_range(const std::string & range, bool (&boolmask)[GGML_MAX_N_THRE
|
||||
} else {
|
||||
end_i = std::stoull(range.substr(dash_loc + 1));
|
||||
if (end_i >= GGML_MAX_N_THREADS) {
|
||||
LOG_ERR("End index out of bounds!\n");
|
||||
COM_ERR("%s", "End index out of bounds!\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -333,7 +333,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
}
|
||||
|
||||
size_t num_digits = mask.length() - start_i;
|
||||
if (num_digits > 128) num_digits = 128;
|
||||
num_digits = std::min<size_t>(num_digits, 128);
|
||||
|
||||
size_t end_i = num_digits + start_i;
|
||||
|
||||
@@ -348,7 +348,7 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
|
||||
} else if (c >= 'A' && c <= 'F') {
|
||||
id -= 'A' - 10;
|
||||
} else {
|
||||
LOG_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
|
||||
COM_ERR("Invalid hex character '%c' at position %d\n", c, int32_t(i));
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -379,21 +379,21 @@ void common_params_print_info(const common_params & params, bool print_devices)
|
||||
#else
|
||||
const char * build_type = " (debug)";
|
||||
#endif
|
||||
LOG_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
|
||||
COM_TRC("%s: build %d (%s) with %s for %s%s\n", __func__, llama_build_number(), llama_commit(), llama_compiler(), llama_build_target(), build_type);
|
||||
|
||||
LOG_INF("log_info: verbosity = %d (adjust with the `-lv N` CLI arg)\n", common_log_get_verbosity_thold());
|
||||
COM_INF("%s: verbosity = %d (adjust with the `-lv N` CLI arg)\n", __func__, common_log_get_verbosity_thold());
|
||||
|
||||
// device enumeration creates a primary context on CUDA backends, skip it when the caller does not own any device
|
||||
if (print_devices) {
|
||||
LOG_INF("device_info:\n");
|
||||
COM_TRC("%s", "device_info:\n");
|
||||
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
|
||||
auto * dev = ggml_backend_dev_get(i);
|
||||
size_t free, total;
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
LOG_INF(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
COM_TRC(" - %-8s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
COM_TRC("%s\n", common_params_get_system_info(params).c_str());
|
||||
}
|
||||
|
||||
std::string common_params_get_system_info(const common_params & params) {
|
||||
@@ -660,7 +660,7 @@ void string_process_escapes(std::string & input) {
|
||||
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
|
||||
const char * sep = strchr(data, '=');
|
||||
if (sep == nullptr || sep - data >= 128) {
|
||||
LOG_ERR("%s: malformed KV override '%s'\n", __func__, data);
|
||||
COM_ERR("%s: malformed KV override '%s'\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
llama_model_kv_override kvo;
|
||||
@@ -683,20 +683,20 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
|
||||
} else if (std::strcmp(sep, "false") == 0) {
|
||||
kvo.val_bool = false;
|
||||
} else {
|
||||
LOG_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
|
||||
COM_ERR("%s: invalid boolean value for KV override '%s'\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
} else if (strncmp(sep, "str:", 4) == 0) {
|
||||
sep += 4;
|
||||
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
|
||||
if (strlen(sep) > 127) {
|
||||
LOG_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
|
||||
COM_ERR("%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
strncpy(kvo.val_str, sep, 127);
|
||||
kvo.val_str[127] = '\0';
|
||||
} else {
|
||||
LOG_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
|
||||
COM_ERR("%s: invalid type for KV override '%s'\n", __func__, data);
|
||||
return false;
|
||||
}
|
||||
overrides.emplace_back(std::move(kvo));
|
||||
@@ -1199,8 +1199,8 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
if (params.fit_params) {
|
||||
LOG_INF("%s: fitting params to device memory ...\n", __func__);
|
||||
LOG_INF("%s: (for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n", __func__);
|
||||
COM_TRC("%s", "fitting params to device memory ...\n");
|
||||
COM_TRC("%s", "(for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)\n");
|
||||
common_fit_params(params.model.path.c_str(), &mparams, &cparams,
|
||||
params.tensor_split,
|
||||
params.tensor_buft_overrides.data(),
|
||||
@@ -1227,7 +1227,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
llama_adapter_lora_ptr lora;
|
||||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
if (lora == nullptr) {
|
||||
LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str());
|
||||
COM_ERR("failed to load lora adapter '%s'\n", la.path.c_str());
|
||||
pimpl->model.reset(model);
|
||||
return;
|
||||
}
|
||||
@@ -1246,14 +1246,14 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
common_init_sampler_from_model(model, params.sampling);
|
||||
|
||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have an EOS token, ignoring --ignore-eos\n");
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_TRC("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
|
||||
COM_TRC("added %s logit bias = %f\n", common_token_to_piece(vocab, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
@@ -1291,7 +1291,7 @@ common_init_result::common_init_result(common_params & params, bool model_only)
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1328,7 +1328,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
|
||||
llama_model * model = res->model();
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
|
||||
COM_ERR("failed to load model '%s'\n", params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -1338,14 +1338,14 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
|
||||
llama_context * lctx = res->context();
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str());
|
||||
COM_ERR("failed to create context with model '%s'\n", params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
|
||||
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
|
||||
COM_WRN("%s", "KV cache shifting is not supported for this context, disabling KV cache shifting\n");
|
||||
params.ctx_shift = false;
|
||||
}
|
||||
|
||||
@@ -1374,7 +1374,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
bool ok = true;
|
||||
|
||||
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have a BOS token, reranking will not work\n");
|
||||
ok = false;
|
||||
}
|
||||
|
||||
@@ -1383,10 +1383,10 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
bool has_rerank_prompt = llama_model_chat_template(model, "rerank") != NULL;
|
||||
|
||||
if (!has_eos && !has_sep && !has_rerank_prompt) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have an EOS token, SEP token, or rerank prompt. Reranking will not work\n");
|
||||
ok = false;
|
||||
} else if (!has_eos) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
|
||||
COM_WRN("%s", "vocab does not have an EOS token, using SEP token as fallback\n");
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
@@ -1399,7 +1399,7 @@ common_init_result_ptr common_init_from_params(common_params & params, bool mode
|
||||
}
|
||||
|
||||
if (params.warmup) {
|
||||
LOG_INF("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||
COM_TRC("%s", "warming up the model with an empty run - please wait ... (--no-warmup to disable)\n");
|
||||
|
||||
std::vector<llama_token> tmp;
|
||||
llama_token bos = llama_vocab_bos(vocab);
|
||||
@@ -1473,20 +1473,20 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
|
||||
|
||||
int ret = llama_decode(ctx, llama_batch_get_one(tmp.data(), tmp.size()));
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
|
||||
COM_ERR("llama_decode() failed: %d\n", ret);
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
|
||||
goto done;
|
||||
}
|
||||
|
||||
if (llama_n_rs_seq(ctx) > 0) {
|
||||
LOG_INF("%s: the context supports bounded partial sequence removal\n", __func__);
|
||||
COM_TRC("%s", "the context supports bounded partial sequence removal\n");
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_RS;
|
||||
goto done;
|
||||
}
|
||||
|
||||
// try to remove the last tokens
|
||||
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
|
||||
LOG_TRC("%s: the context does not support partial sequence removal\n", __func__);
|
||||
COM_TRC("%s", "the context does not support partial sequence removal\n");
|
||||
res = COMMON_CONTEXT_SEQ_RM_TYPE_FULL;
|
||||
goto done;
|
||||
}
|
||||
@@ -1803,13 +1803,13 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
};
|
||||
struct gguf_context * ctx_gguf = gguf_init_from_file(load_info.fname.c_str(), meta_gguf_params);
|
||||
if (!ctx_gguf) {
|
||||
LOG_ERR("%s: failed to load control vector file from %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("failed to load control vector file from %s\n", load_info.fname.c_str());
|
||||
return result;
|
||||
}
|
||||
|
||||
int32_t n_tensors = gguf_get_n_tensors(ctx_gguf);
|
||||
if (n_tensors == 0) {
|
||||
LOG_WRN("%s: no direction tensors found in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_WRN("no direction tensors found in %s\n", load_info.fname.c_str());
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_tensors; i++) {
|
||||
@@ -1827,23 +1827,23 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
}
|
||||
}
|
||||
if (layer_idx < 0) {
|
||||
LOG_ERR("%s: invalid/unparsable direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid/unparsable direction tensor layer index in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
} else if (layer_idx == 0) {
|
||||
LOG_ERR("%s: invalid (zero) direction tensor layer index in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid (zero) direction tensor layer index in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
|
||||
struct ggml_tensor * tensor = ggml_get_tensor(ctx, name.c_str());
|
||||
if (tensor->type != GGML_TYPE_F32) {
|
||||
LOG_ERR("%s: invalid (non-F32) direction tensor type in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid (non-F32) direction tensor type in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
if (ggml_n_dims(tensor) != 1) {
|
||||
LOG_ERR("%s: invalid (non-1D) direction tensor shape in %s\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("invalid (non-1D) direction tensor shape in %s\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
@@ -1851,7 +1851,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
if (result.n_embd == -1) {
|
||||
result.n_embd = ggml_nelements(tensor);
|
||||
} else if (ggml_nelements(tensor) != result.n_embd) {
|
||||
LOG_ERR("%s: direction tensor in %s does not match previous dimensions\n", __func__, load_info.fname.c_str());
|
||||
COM_ERR("direction tensor in %s does not match previous dimensions\n", load_info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
@@ -1868,7 +1868,7 @@ static common_control_vector_data common_control_vector_load_one(const common_co
|
||||
}
|
||||
|
||||
if (result.n_embd == -1) {
|
||||
LOG_WRN("%s: skipping %s due to invalid direction tensors\n", __func__, load_info.fname.c_str());
|
||||
COM_WRN("skipping %s due to invalid direction tensors\n", load_info.fname.c_str());
|
||||
result.data.clear();
|
||||
}
|
||||
|
||||
@@ -1889,7 +1889,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
|
||||
break;
|
||||
}
|
||||
if (result.n_embd != -1 && result.n_embd != cur.n_embd) {
|
||||
LOG_ERR("%s: control vectors in %s does not match previous dimensions\n", __func__, info.fname.c_str());
|
||||
COM_ERR("control vectors in %s does not match previous dimensions\n", info.fname.c_str());
|
||||
result.n_embd = -1;
|
||||
break;
|
||||
}
|
||||
@@ -1905,7 +1905,7 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
|
||||
}
|
||||
|
||||
if (result.n_embd == -1) {
|
||||
LOG_ERR("%s: no valid control vector files passed\n", __func__);
|
||||
COM_ERR("%s", "no valid control vector files passed\n");
|
||||
result.data.clear();
|
||||
}
|
||||
|
||||
@@ -2016,13 +2016,13 @@ bool common_prompt_batch_decode(
|
||||
// memory, so we can't just remove the last token from the memory and replay the last token which
|
||||
// is the reason for this logic.
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_tokens_before_last))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
COM_ERR("%s", "failed to eval\n");
|
||||
return false;
|
||||
}
|
||||
n_past += n_tokens_before_last;
|
||||
|
||||
llama_state_save_file(ctx, state_path.data(), all_tokens.data(), all_tokens.size());
|
||||
LOG_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
|
||||
COM_INF("saved session before last token to %s, n_new = %zu\n", state_path.data(), all_tokens.size());
|
||||
|
||||
llama_token last_token = all_tokens.back();
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
@@ -2030,13 +2030,13 @@ bool common_prompt_batch_decode(
|
||||
batch.pos = &pos;
|
||||
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval last token\n", __func__);
|
||||
COM_ERR("%s", "failed to eval last token\n");
|
||||
return false;
|
||||
}
|
||||
n_past++;
|
||||
} else {
|
||||
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(all_tokens.data() + offset), n_new))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
COM_ERR("%s", "failed to eval\n");
|
||||
return false;
|
||||
}
|
||||
n_past += n_new;
|
||||
|
||||
+9
-1
@@ -25,6 +25,13 @@
|
||||
#define DIRECTORY_SEPARATOR '/'
|
||||
#endif // _WIN32
|
||||
|
||||
#define COM_DBG(fmt, ...) LOG_DBG("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_TRC(fmt, ...) LOG_TRC("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_INF(fmt, ...) LOG_INF("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_WRN(fmt, ...) LOG_WRN("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_ERR(fmt, ...) LOG_ERR("cmn %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define COM_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
|
||||
#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
|
||||
#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
|
||||
|
||||
@@ -162,6 +169,7 @@ enum common_speculative_type {
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction
|
||||
COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, // DFlash speculative decoding
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
|
||||
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
|
||||
@@ -377,7 +385,7 @@ struct common_params_speculative {
|
||||
|
||||
uint32_t need_n_rs_seq() const {
|
||||
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
|
||||
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
|
||||
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3 || t == COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH;
|
||||
});
|
||||
|
||||
return needs_rs_seq ? draft.n_max : 0u;
|
||||
|
||||
+1
-1
@@ -233,7 +233,7 @@ static void common_params_fit_impl(
|
||||
sum_projected_used = dmds_full.back().mb.total();
|
||||
sum_free = dmds_full.back().total;
|
||||
sum_projected_free = sum_free - sum_projected_used;
|
||||
LOG_INF("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
|
||||
LOG_TRC("%s: projected to use %" PRId64 " MiB of host memory vs. %" PRId64 " MiB of total host memory\n",
|
||||
__func__, sum_projected_used/MiB, sum_free/MiB);
|
||||
if (sum_projected_free >= margins[0]) {
|
||||
LOG_TRC("%s: will leave %" PRId64 " >= %" PRId64 " MiB of system memory, no changes needed\n",
|
||||
|
||||
+28
-6
@@ -11,6 +11,11 @@ struct common_http_url {
|
||||
std::string path;
|
||||
};
|
||||
|
||||
// bracket an IPv6 literal host for a URL authority (RFC 3986)
|
||||
static std::string common_http_format_host(const std::string & host) {
|
||||
return host.find(':') != std::string::npos ? "[" + host + "]" : host;
|
||||
}
|
||||
|
||||
static common_http_url common_http_parse_url(const std::string & url) {
|
||||
common_http_url parts;
|
||||
auto scheme_end = url.find("://");
|
||||
@@ -49,11 +54,28 @@ static common_http_url common_http_parse_url(const std::string & url) {
|
||||
parts.path = "/";
|
||||
}
|
||||
|
||||
auto colon_pos = parts.host.find(':');
|
||||
// split the authority into host and optional port, a bracketed IPv6 literal keeps its inner colons (RFC 3986)
|
||||
std::string port_str;
|
||||
if (!parts.host.empty() && parts.host.front() == '[') {
|
||||
auto close = parts.host.find(']');
|
||||
if (close == std::string::npos) {
|
||||
throw std::runtime_error("invalid IPv6 URL authority: " + parts.host);
|
||||
}
|
||||
auto after = parts.host.substr(close + 1);
|
||||
if (!after.empty() && after.front() == ':') {
|
||||
port_str = after.substr(1);
|
||||
}
|
||||
parts.host = parts.host.substr(1, close - 1);
|
||||
} else {
|
||||
auto colon_pos = parts.host.find(':');
|
||||
if (colon_pos != std::string::npos) {
|
||||
port_str = parts.host.substr(colon_pos + 1);
|
||||
parts.host = parts.host.substr(0, colon_pos);
|
||||
}
|
||||
}
|
||||
|
||||
if (colon_pos != std::string::npos) {
|
||||
parts.port = std::stoi(parts.host.substr(colon_pos + 1));
|
||||
parts.host = parts.host.substr(0, colon_pos);
|
||||
if (!port_str.empty()) {
|
||||
parts.port = std::stoi(port_str);
|
||||
} else if (parts.scheme == "http") {
|
||||
parts.port = 80;
|
||||
} else if (parts.scheme == "https") {
|
||||
@@ -83,7 +105,7 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
|
||||
}
|
||||
#endif
|
||||
|
||||
httplib::Client cli(parts.scheme + "://" + parts.host + ":" + std::to_string(parts.port));
|
||||
httplib::Client cli(parts.scheme + "://" + common_http_format_host(parts.host) + ":" + std::to_string(parts.port));
|
||||
|
||||
if (!parts.user.empty()) {
|
||||
cli.set_basic_auth(parts.user, parts.password);
|
||||
@@ -95,5 +117,5 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
|
||||
}
|
||||
|
||||
static std::string common_http_show_masked_url(const common_http_url & parts) {
|
||||
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
|
||||
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + common_http_format_host(parts.host) + parts.path;
|
||||
}
|
||||
|
||||
+44
-23
@@ -16,22 +16,34 @@ using json = nlohmann::ordered_json;
|
||||
namespace jinja {
|
||||
|
||||
using caps_json_fn = std::function<json()>;
|
||||
using caps_analyze_fn = std::function<void(bool, value &, value &)>;
|
||||
using caps_ctx_fn = std::function<void(context &)>;
|
||||
using caps_analyze_fn = std::function<void(bool, value &, value &, const std::string &)>;
|
||||
|
||||
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled) {
|
||||
ctx.set_val("preserve_thinking", mk_val<value_bool>(enabled));
|
||||
ctx.set_val("clear_thinking", mk_val<value_bool>(!enabled));
|
||||
ctx.set_val("truncate_history_thinking", mk_val<value_bool>(!enabled));
|
||||
}
|
||||
|
||||
static void caps_try_execute(jinja::program & prog,
|
||||
const caps_json_fn & messages_fn,
|
||||
const caps_ctx_fn & ctx_fn,
|
||||
const caps_json_fn & tools_fn,
|
||||
const caps_analyze_fn & analyze_fn) {
|
||||
context ctx;
|
||||
ctx.is_get_stats = true;
|
||||
jinja::global_from_json(ctx, json{
|
||||
{"messages", messages_fn()},
|
||||
{"tools", tools_fn()},
|
||||
{"tools", tools_fn ? tools_fn() : json::array()},
|
||||
{"bos_token", ""},
|
||||
{"eos_token", ""},
|
||||
{"add_generation_prompt", true}
|
||||
}, true);
|
||||
|
||||
if (ctx_fn) {
|
||||
ctx_fn(ctx);
|
||||
}
|
||||
|
||||
auto messages = ctx.get_val("messages");
|
||||
auto tools = ctx.get_val("tools");
|
||||
|
||||
@@ -49,7 +61,7 @@ static void caps_try_execute(jinja::program & prog,
|
||||
// ignore exceptions during capability analysis
|
||||
}
|
||||
|
||||
analyze_fn(success, messages, tools);
|
||||
analyze_fn(success, messages, tools, result);
|
||||
}
|
||||
|
||||
// for debugging only
|
||||
@@ -109,11 +121,9 @@ caps caps_get(jinja::program & prog) {
|
||||
}
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json{nullptr};
|
||||
},
|
||||
[&](bool success, value & messages, value &) {
|
||||
nullptr, // ctx_fn
|
||||
nullptr, // tools_fn
|
||||
[&](bool success, value & messages, value &, const std::string &) {
|
||||
auto & content = messages->at(0)->at("content");
|
||||
caps_print_stats(content, "messages[0].content");
|
||||
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
|
||||
@@ -145,11 +155,9 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array();
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
nullptr, // ctx_fn
|
||||
nullptr, // tools_fn
|
||||
[&](bool, value & messages, value &, const std::string &) {
|
||||
auto & content = messages->at(0)->at("content");
|
||||
caps_print_stats(content, "messages[0].content");
|
||||
if (!content->stats.used) {
|
||||
@@ -201,6 +209,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
nullptr, // ctx_fn
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array({
|
||||
@@ -224,7 +233,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&](bool success, value & messages, value & tools) {
|
||||
[&](bool success, value & messages, value & tools, const std::string &) {
|
||||
if (!success) {
|
||||
return; // Nothing can be inferred
|
||||
}
|
||||
@@ -293,6 +302,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
nullptr, // ctx_fn
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array({
|
||||
@@ -316,7 +326,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&](bool success, value & messages, value & tools) {
|
||||
[&](bool success, value & messages, value & tools, const std::string &) {
|
||||
if (!success) {
|
||||
result.supports_tool_calls = false;
|
||||
result.supports_tools = false;
|
||||
@@ -394,6 +404,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
nullptr, // ctx_fn
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array({
|
||||
@@ -417,7 +428,7 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&](bool success, value & messages, value & /*tools*/) {
|
||||
[&](bool success, value & messages, value &, const std::string &) {
|
||||
if (!success) {
|
||||
result.supports_parallel_tool_calls = false;
|
||||
return;
|
||||
@@ -438,11 +449,22 @@ caps caps_get(jinja::program & prog) {
|
||||
JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning");
|
||||
|
||||
// case: preserve reasoning content in chat history
|
||||
const std::string reasoning_placeholder = "<REASONING_CONTENT_PLACEHOLDER>";
|
||||
caps_try_execute(
|
||||
prog,
|
||||
[&]() {
|
||||
// messages
|
||||
return json::array({
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
},
|
||||
{
|
||||
{"role", "assistant"},
|
||||
{"content", "Assistant message"},
|
||||
// check of reasoning_content deeper in the history, not just the last assistant message
|
||||
{"reasoning_content", reasoning_placeholder}
|
||||
},
|
||||
{
|
||||
{"role", "user"},
|
||||
{"content", "User message"}
|
||||
@@ -458,14 +480,13 @@ caps caps_get(jinja::program & prog) {
|
||||
},
|
||||
});
|
||||
},
|
||||
[&]() {
|
||||
// tools
|
||||
return json::array();
|
||||
[&](context & ctx) {
|
||||
caps_apply_preserve_reasoning(ctx, true);
|
||||
},
|
||||
[&](bool, value & messages, value &) {
|
||||
auto & content = messages->at(1)->at("reasoning_content");
|
||||
caps_print_stats(content, "messages[1].reasoning_content");
|
||||
if (content->stats.used) {
|
||||
nullptr, // tools_fn
|
||||
[&](bool, value &, value &, const std::string & output) {
|
||||
// note: we cannot use stats here because the reasoning_content may be used for "if" condition test, but not actually outputted in the final result
|
||||
if (output.find(reasoning_placeholder) != std::string::npos) {
|
||||
result.supports_preserve_reasoning = true;
|
||||
}
|
||||
}
|
||||
|
||||
+5
-1
@@ -12,7 +12,9 @@ struct caps {
|
||||
bool supports_tool_calls = true;
|
||||
bool supports_system_role = true;
|
||||
bool supports_parallel_tool_calls = true;
|
||||
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
|
||||
|
||||
// supports preserve reasoning trace in the full history, not just the last assistant message
|
||||
bool supports_preserve_reasoning = false;
|
||||
|
||||
// one of the 2 content capabilities must be true
|
||||
bool supports_string_content = true;
|
||||
@@ -29,4 +31,6 @@ struct caps {
|
||||
|
||||
caps caps_get(jinja::program & prog);
|
||||
|
||||
void caps_apply_preserve_reasoning(jinja::context & ctx, bool enabled);
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -954,4 +954,50 @@ value keyword_argument_expression::execute_impl(context & ctx) {
|
||||
return mk_val<value_kwarg>(k, v);
|
||||
}
|
||||
|
||||
std::string runtime::debug_dump_program(const program & prog, const std::string & src) {
|
||||
std::ostringstream oss;
|
||||
size_t lvl = 0;
|
||||
context ctx;
|
||||
ctx.src.reset(new std::string(src));
|
||||
|
||||
auto indent = [](size_t lvl) -> std::string {
|
||||
return std::string(lvl * 2, ' ');
|
||||
};
|
||||
|
||||
ctx.visitor = [&](bool is_leaf, statement * node, std::vector<visitor_pair> children) {
|
||||
oss << indent(lvl) << node->type() << ":\n";
|
||||
lvl++;
|
||||
if (is_leaf) {
|
||||
const auto & pos = node->pos;
|
||||
oss << indent(lvl) << "(leaf) at " << get_line_col(src, pos) << " in source:\n";
|
||||
std::string snippet = peak_source(src, pos);
|
||||
string_replace_all(snippet, "\n", "\n" + indent(lvl));
|
||||
oss << indent(lvl) << snippet << "\n";
|
||||
} else {
|
||||
for (auto & [label, children_vec] : children) {
|
||||
oss << indent(lvl) << label << ":\n";
|
||||
lvl++;
|
||||
if (children_vec.empty()) {
|
||||
oss << indent(lvl) << "<empty>\n\n";
|
||||
} else {
|
||||
for (auto * child : children_vec) {
|
||||
if (!child) {
|
||||
continue;
|
||||
}
|
||||
child->visit(ctx);
|
||||
}
|
||||
}
|
||||
lvl--;
|
||||
}
|
||||
}
|
||||
lvl--;
|
||||
};
|
||||
|
||||
for (const auto & stmt : prog.body) {
|
||||
stmt->visit(ctx);
|
||||
}
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -47,12 +47,19 @@ const T * cast_stmt(const statement_ptr & ptr) {
|
||||
// not thread-safe
|
||||
void enable_debug(bool enable);
|
||||
|
||||
// for visiting AST nodes
|
||||
// function signature: void(bool is_leaf, statement * node, pair of <label, children>)
|
||||
using visitor_pair = std::pair<std::string, std::vector<statement *>>;
|
||||
using visitor_fn = std::function<void(bool, statement *, std::vector<visitor_pair>)>;
|
||||
|
||||
struct context {
|
||||
std::shared_ptr<std::string> src; // for debugging; use shared_ptr to avoid copying on scope creation
|
||||
std::time_t current_time; // for functions that need current time
|
||||
|
||||
bool is_get_stats = false; // whether to collect stats
|
||||
|
||||
visitor_fn visitor;
|
||||
|
||||
// src is optional, used for error reporting
|
||||
context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
|
||||
env = mk_val<value_object>();
|
||||
@@ -99,6 +106,15 @@ private:
|
||||
value_object env;
|
||||
};
|
||||
|
||||
// utils for visiting AST nodes
|
||||
static std::vector<statement *> stmts_to_ptr(const statements & stmts) {
|
||||
std::vector<statement *> children;
|
||||
for (const auto & stmt : stmts) {
|
||||
children.push_back(stmt.get());
|
||||
}
|
||||
return children;
|
||||
}
|
||||
|
||||
/**
|
||||
* Base class for all nodes in the AST.
|
||||
*/
|
||||
@@ -106,6 +122,7 @@ struct statement {
|
||||
size_t pos; // position in source, for debugging
|
||||
virtual ~statement() = default;
|
||||
virtual std::string type() const { return "Statement"; }
|
||||
virtual void visit(context & ctx) { ctx.visitor(true, this, {}); }
|
||||
|
||||
// execute_impl must be overridden by derived classes
|
||||
virtual value execute_impl(context &) { throw_exec_error(); }
|
||||
@@ -166,6 +183,13 @@ struct if_statement : public statement {
|
||||
|
||||
std::string type() const override { return "If"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"test", {test.get()}},
|
||||
{"body", stmts_to_ptr(body)},
|
||||
{"alternate", stmts_to_ptr(alternate)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct identifier;
|
||||
@@ -190,6 +214,14 @@ struct for_statement : public statement {
|
||||
|
||||
std::string type() const override { return "For"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"loopvar", {loopvar.get()}},
|
||||
{"iterable", {iterable.get()}},
|
||||
{"body", stmts_to_ptr(body)},
|
||||
{"default_block", stmts_to_ptr(default_block)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct break_statement : public statement {
|
||||
@@ -241,6 +273,13 @@ struct set_statement : public statement {
|
||||
|
||||
std::string type() const override { return "Set"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"assignee", {assignee.get()}},
|
||||
{"value", {val.get()}},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct macro_statement : public statement {
|
||||
@@ -256,6 +295,13 @@ struct macro_statement : public statement {
|
||||
|
||||
std::string type() const override { return "Macro"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"name", {name.get()}},
|
||||
{"args", stmts_to_ptr(args)},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct comment_statement : public statement {
|
||||
@@ -289,6 +335,12 @@ struct member_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "MemberExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"object", {object.get()}},
|
||||
{"property", {property.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct call_expression : public expression {
|
||||
@@ -302,6 +354,12 @@ struct call_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "CallExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"callee", {callee.get()}},
|
||||
{"args", stmts_to_ptr(args)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -405,6 +463,12 @@ struct binary_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "BinaryExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"left", {left.get()}},
|
||||
{"right", {right.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -431,6 +495,12 @@ struct filter_expression : public expression {
|
||||
|
||||
std::string type() const override { return "FilterExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"operand", {operand.get()}},
|
||||
{"filter", {filter.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct filter_statement : public statement {
|
||||
@@ -443,6 +513,12 @@ struct filter_statement : public statement {
|
||||
}
|
||||
std::string type() const override { return "FilterStatement"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"filter", {filter.get()}},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -468,6 +544,12 @@ struct select_expression : public expression {
|
||||
}
|
||||
return lhs->execute_impl(ctx);
|
||||
}
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"lhs", {lhs.get()}},
|
||||
{"test", {test.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -486,6 +568,12 @@ struct test_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "TestExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"operand", {operand.get()}},
|
||||
{"test", {test.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -501,6 +589,11 @@ struct unary_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "UnaryExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"argument", {argument.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct slice_expression : public expression {
|
||||
@@ -518,6 +611,13 @@ struct slice_expression : public expression {
|
||||
[[noreturn]] value execute_impl(context &) override {
|
||||
throw std::runtime_error("must be handled by MemberExpression");
|
||||
}
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"start_expr", {start_expr.get()}},
|
||||
{"stop_expr", {stop_expr.get()}},
|
||||
{"step_expr", {step_expr.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct keyword_argument_expression : public expression {
|
||||
@@ -531,6 +631,12 @@ struct keyword_argument_expression : public expression {
|
||||
}
|
||||
std::string type() const override { return "KeywordArgumentExpression"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"key", {key.get()}},
|
||||
{"val", {val.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct spread_expression : public expression {
|
||||
@@ -539,6 +645,11 @@ struct spread_expression : public expression {
|
||||
chk_type<expression>(this->argument);
|
||||
}
|
||||
std::string type() const override { return "SpreadExpression"; }
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"argument", {argument.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct call_statement : public statement {
|
||||
@@ -553,6 +664,13 @@ struct call_statement : public statement {
|
||||
}
|
||||
std::string type() const override { return "CallStatement"; }
|
||||
value execute_impl(context & ctx) override;
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"call", {call.get()}},
|
||||
{"caller_args", stmts_to_ptr(caller_args)},
|
||||
{"body", stmts_to_ptr(body)}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct ternary_expression : public expression {
|
||||
@@ -575,6 +693,13 @@ struct ternary_expression : public expression {
|
||||
return false_expr->execute(ctx);
|
||||
}
|
||||
}
|
||||
void visit(context & ctx) override {
|
||||
ctx.visitor(false, this, {
|
||||
{"condition", {condition.get()}},
|
||||
{"true_expr", {true_expr.get()}},
|
||||
{"false_expr", {false_expr.get()}}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
struct raised_exception : public std::exception {
|
||||
@@ -648,6 +773,8 @@ struct runtime {
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
static std::string debug_dump_program(const program & prog, const std::string & src);
|
||||
};
|
||||
|
||||
} // namespace jinja
|
||||
|
||||
@@ -1108,6 +1108,50 @@ const func_builtins & value_array_t::get_builtins() const {
|
||||
std::reverse(arr.begin(), arr.end());
|
||||
return is_val<value_tuple>(val) ? mk_val<value_tuple>(std::move(arr)) : mk_val<value_array>(std::move(arr));
|
||||
}},
|
||||
{"min", [](const func_args & args) -> value {
|
||||
args.ensure_count(1, 4);
|
||||
args.ensure_vals<value_array>();
|
||||
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 2);
|
||||
if (!attribute->is_undefined()) {
|
||||
throw not_implemented_exception("min: attribute not implemented");
|
||||
}
|
||||
// FIXME: min is currently always case sensitive
|
||||
(void) val_case;
|
||||
const auto & arr = args.get_pos(0)->as_array();
|
||||
if (arr.empty()) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
value result = arr[0];
|
||||
for (size_t i = 1; i < arr.size(); ++i) {
|
||||
if (value_compare(arr[i], result, value_compare_op::lt)) {
|
||||
result = arr[i];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}},
|
||||
{"max", [](const func_args & args) -> value {
|
||||
args.ensure_count(1, 4);
|
||||
args.ensure_vals<value_array>();
|
||||
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
|
||||
value attribute = args.get_kwarg_or_pos("attribute", 2);
|
||||
if (!attribute->is_undefined()) {
|
||||
throw not_implemented_exception("max: attribute not implemented");
|
||||
}
|
||||
// FIXME: max is currently always case sensitive
|
||||
(void) val_case;
|
||||
const auto & arr = args.get_pos(0)->as_array();
|
||||
if (arr.empty()) {
|
||||
return mk_val<value_undefined>();
|
||||
}
|
||||
value result = arr[0];
|
||||
for (size_t i = 1; i < arr.size(); ++i) {
|
||||
if (value_compare(arr[i], result, value_compare_op::gt)) {
|
||||
result = arr[i];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}},
|
||||
{"unique", array_unique_not_implemented},
|
||||
};
|
||||
return builtins;
|
||||
|
||||
+29
-4
@@ -7,6 +7,7 @@
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <filesystem>
|
||||
#include <regex>
|
||||
|
||||
static std::string rm_leading_dashes(const std::string & str) {
|
||||
size_t pos = 0;
|
||||
@@ -16,6 +17,23 @@ static std::string rm_leading_dashes(const std::string & str) {
|
||||
return str.substr(pos);
|
||||
}
|
||||
|
||||
static std::string canonical_tag(const std::string & tag) {
|
||||
static const std::regex re_tag("[-.]([A-Z0-9_]+)$", std::regex::icase);
|
||||
std::smatch m;
|
||||
if (std::regex_search(tag, m, re_tag)) {
|
||||
std::string canon = m[1].str();
|
||||
for (char & c : canon) {
|
||||
c = (char) std::toupper((unsigned char) c);
|
||||
}
|
||||
return canon;
|
||||
}
|
||||
std::string upper = tag;
|
||||
for (char & c : upper) {
|
||||
c = (char) std::toupper((unsigned char) c);
|
||||
}
|
||||
return upper;
|
||||
}
|
||||
|
||||
std::vector<std::string> common_preset::to_args(const std::string & bin_path) const {
|
||||
std::vector<std::string> args;
|
||||
|
||||
@@ -270,11 +288,18 @@ common_presets common_preset_context::load_from_ini(const std::string & path, co
|
||||
|
||||
for (auto section : ini_data) {
|
||||
common_preset preset;
|
||||
if (section.first.empty()) {
|
||||
preset.name = COMMON_PRESET_DEFAULT_NAME;
|
||||
} else {
|
||||
preset.name = section.first;
|
||||
std::string section_name = section.first.empty() ? std::string(COMMON_PRESET_DEFAULT_NAME) : section.first;
|
||||
if (section_name != "*" && section_name != COMMON_PRESET_DEFAULT_NAME) {
|
||||
auto colon_idx = section_name.rfind(':');
|
||||
if (colon_idx != std::string::npos) {
|
||||
std::string tag = section_name.substr(colon_idx + 1);
|
||||
std::string canon_tag = canonical_tag(tag);
|
||||
if (canon_tag != tag) {
|
||||
section_name = section_name.substr(0, colon_idx + 1) + canon_tag;
|
||||
}
|
||||
}
|
||||
}
|
||||
preset.name = section_name;
|
||||
LOG_DBG("loading preset: %s\n", preset.name.c_str());
|
||||
for (const auto & [key, value] : section.second) {
|
||||
if (key == "version") {
|
||||
|
||||
+10
-10
@@ -65,12 +65,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
if (ctx->start_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_COUNTING;
|
||||
ctx->remaining = ctx->budget;
|
||||
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
|
||||
COM_TRC("activated, budget=%d tokens\n", ctx->budget);
|
||||
|
||||
if (ctx->remaining <= 0) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
|
||||
COM_TRC("%s", "budget=0, forcing immediately\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -80,7 +80,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
{
|
||||
if (ctx->end_matcher.advance(token)) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: deactivated (natural end)\n");
|
||||
COM_TRC("%s", "deactivated (natural end)\n");
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
|
||||
COM_TRC("%s", "UTF-8 complete, now forcing end sequence\n");
|
||||
}
|
||||
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
|
||||
ctx->remaining--;
|
||||
@@ -104,11 +104,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
|
||||
COM_TRC("%s", "budget exhausted, forcing end sequence\n");
|
||||
} else {
|
||||
ctx->state = REASONING_BUDGET_WAITING_UTF8;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
|
||||
COM_TRC("%s", "budget exhausted, waiting for UTF-8 completion\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -118,7 +118,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->force_pos++;
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
ctx->state = REASONING_BUDGET_DONE;
|
||||
LOG_INF("reasoning-budget: forced sequence complete, done\n");
|
||||
COM_TRC("%s", "forced sequence complete, done\n");
|
||||
}
|
||||
break;
|
||||
case REASONING_BUDGET_DONE:
|
||||
@@ -128,12 +128,12 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
||||
ctx->state = REASONING_BUDGET_COUNTING;
|
||||
ctx->remaining = ctx->budget;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);
|
||||
COM_TRC("re-activated on new start tag, budget=%d tokens\n", ctx->budget);
|
||||
|
||||
if (ctx->remaining <= 0) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
|
||||
COM_TRC("%s", "budget=0, forcing immediately\n");
|
||||
}
|
||||
}
|
||||
break;
|
||||
@@ -264,7 +264,7 @@ bool common_reasoning_budget_force(struct llama_sampler * smpl) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: forced into forcing state (manual transition)\n");
|
||||
COM_TRC("%s", "forced into forcing state (manual transition)\n");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
#include "regex-partial.h"
|
||||
#include "common.h"
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
|
||||
common_regex::common_regex(const std::string & pattern) :
|
||||
pattern(pattern),
|
||||
rx(pattern),
|
||||
rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
|
||||
|
||||
common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
|
||||
std::smatch match;
|
||||
if (pos > input.size()) {
|
||||
throw std::runtime_error("Position out of bounds");
|
||||
}
|
||||
auto start = input.begin() + pos;
|
||||
auto found = as_match
|
||||
? std::regex_match(start, input.end(), match, rx)
|
||||
: std::regex_search(start, input.end(), match, rx);
|
||||
if (found) {
|
||||
common_regex_match res;
|
||||
res.type = COMMON_REGEX_MATCH_TYPE_FULL;
|
||||
for (size_t i = 0; i < match.size(); ++i) {
|
||||
auto begin = pos + match.position(i);
|
||||
res.groups.emplace_back(begin, begin + match.length(i));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
std::match_results<std::string::const_reverse_iterator> srmatch;
|
||||
if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
|
||||
auto group = srmatch[1].str();
|
||||
if (group.length() != 0) {
|
||||
auto it = srmatch[1].second.base();
|
||||
// auto position = static_cast<size_t>(std::distance(input.begin(), it));
|
||||
if ((!as_match) || it == input.begin()) {
|
||||
common_regex_match res;
|
||||
res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
|
||||
const size_t begin = std::distance(input.begin(), it);
|
||||
const size_t end = input.size();
|
||||
if (begin == std::string::npos || end == std::string::npos || begin > end) {
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
res.groups.push_back({begin, end});
|
||||
return res;
|
||||
}
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/*
|
||||
Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.
|
||||
|
||||
Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
|
||||
to see if a string ends with a partial regex match, but but it's not in std::regex yet.
|
||||
Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
|
||||
|
||||
- /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
|
||||
- /a|b/ -> ^(a|b)
|
||||
- /a*?/ -> error, could match ""
|
||||
- /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
|
||||
- /.*?ab/ -> ^((?:b)?a) (omit .*)
|
||||
- /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
|
||||
- /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
|
||||
- /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
|
||||
- /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
|
||||
|
||||
The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern.
|
||||
All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
|
||||
*/
|
||||
std::string regex_to_reversed_partial_regex(const std::string & pattern) {
|
||||
auto it = pattern.begin();
|
||||
const auto end = pattern.end();
|
||||
|
||||
std::function<std::string()> process = [&]() {
|
||||
std::vector<std::vector<std::string>> alternatives(1);
|
||||
std::vector<std::string> * sequence = &alternatives.back();
|
||||
|
||||
while (it != end) {
|
||||
if (*it == '[') {
|
||||
auto start = it;
|
||||
++it;
|
||||
while (it != end) {
|
||||
if ((*it == '\\') && (++it != end)) {
|
||||
++it;
|
||||
} else if ((it != end) && (*it == ']')) {
|
||||
break;
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
if (it == end) {
|
||||
throw std::runtime_error("Unmatched '[' in pattern");
|
||||
}
|
||||
++it;
|
||||
sequence->push_back(std::string(start, it));
|
||||
} else if (*it == '*' || *it == '?' || *it == '+') {
|
||||
if (sequence->empty()) {
|
||||
throw std::runtime_error("Quantifier without preceding element");
|
||||
}
|
||||
sequence->back() += *it;
|
||||
auto is_star = *it == '*';
|
||||
++it;
|
||||
if (is_star) {
|
||||
if (it != end && *it == '?') {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
} else if (*it == '{') {
|
||||
if (sequence->empty()) {
|
||||
throw std::runtime_error("Repetition without preceding element");
|
||||
}
|
||||
++it;
|
||||
auto start = it;
|
||||
while (it != end && *it != '}') {
|
||||
++it;
|
||||
}
|
||||
if (it == end) {
|
||||
throw std::runtime_error("Unmatched '{' in pattern");
|
||||
}
|
||||
auto parts = string_split(std::string(start, it), ",");
|
||||
++it;
|
||||
if (parts.size() > 2) {
|
||||
throw std::runtime_error("Invalid repetition range in pattern");
|
||||
}
|
||||
|
||||
auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
|
||||
if (s.empty()) {
|
||||
return def;
|
||||
}
|
||||
return std::stoi(s);
|
||||
};
|
||||
auto min = parseOptInt(parts[0], 0);
|
||||
auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
|
||||
if (min && max && *max < *min) {
|
||||
throw std::runtime_error("Invalid repetition range in pattern");
|
||||
}
|
||||
// Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
|
||||
auto part = sequence->back();
|
||||
sequence->pop_back();
|
||||
for (int i = 0; i < *min; i++) {
|
||||
sequence->push_back(part);
|
||||
}
|
||||
if (max) {
|
||||
for (int i = *min; i < *max; i++) {
|
||||
sequence->push_back(part + "?");
|
||||
}
|
||||
} else {
|
||||
sequence->push_back(part + "*");
|
||||
}
|
||||
} else if (*it == '(') {
|
||||
++it;
|
||||
if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
|
||||
it += 2;
|
||||
}
|
||||
auto sub = process();
|
||||
if (*it != ')') {
|
||||
throw std::runtime_error("Unmatched '(' in pattern");
|
||||
}
|
||||
++it;
|
||||
auto & part = sequence->emplace_back("(?:");
|
||||
part += sub;
|
||||
part += ")";
|
||||
} else if (*it == ')') {
|
||||
break;
|
||||
} else if (*it == '|') {
|
||||
++it;
|
||||
alternatives.emplace_back();
|
||||
sequence = &alternatives.back();
|
||||
} else if (*it == '\\' && (++it != end)) {
|
||||
auto str = std::string("\\") + *it;
|
||||
sequence->push_back(str);
|
||||
++it;
|
||||
} else if (it != end) {
|
||||
sequence->push_back(std::string(1, *it));
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
// /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
|
||||
// if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
|
||||
// We'll do the outermost capturing group and final .* in the enclosing function.
|
||||
std::vector<std::string> res_alts;
|
||||
for (const auto & parts : alternatives) {
|
||||
auto & res = res_alts.emplace_back();
|
||||
for (size_t i = 0; i < parts.size() - 1; i++) {
|
||||
res += "(?:";
|
||||
}
|
||||
for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
|
||||
res += *it;
|
||||
if (it != parts.rend() - 1) {
|
||||
res += ")?";
|
||||
}
|
||||
}
|
||||
}
|
||||
return string_join(res_alts, "|");
|
||||
};
|
||||
auto res = process();
|
||||
if (it != end) {
|
||||
throw std::runtime_error("Unmatched '(' in pattern");
|
||||
}
|
||||
|
||||
return "^(" + res + ")";
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <regex>
|
||||
#include <string>
|
||||
|
||||
enum common_regex_match_type {
|
||||
COMMON_REGEX_MATCH_TYPE_NONE,
|
||||
COMMON_REGEX_MATCH_TYPE_PARTIAL,
|
||||
COMMON_REGEX_MATCH_TYPE_FULL,
|
||||
};
|
||||
|
||||
struct common_string_range {
|
||||
size_t begin;
|
||||
size_t end;
|
||||
common_string_range(size_t begin, size_t end) : begin(begin), end(end) {
|
||||
if (begin > end) {
|
||||
throw std::runtime_error("Invalid range");
|
||||
}
|
||||
}
|
||||
// prevent default ctor
|
||||
common_string_range() = delete;
|
||||
bool empty() const {
|
||||
return begin == end;
|
||||
}
|
||||
bool operator==(const common_string_range & other) const {
|
||||
return begin == other.begin && end == other.end;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_regex_match {
|
||||
common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE;
|
||||
std::vector<common_string_range> groups;
|
||||
|
||||
bool operator==(const common_regex_match & other) const {
|
||||
return type == other.type && groups == other.groups;
|
||||
}
|
||||
bool operator!=(const common_regex_match & other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
};
|
||||
|
||||
class common_regex {
|
||||
std::string pattern;
|
||||
std::regex rx;
|
||||
std::regex rx_reversed_partial;
|
||||
|
||||
public:
|
||||
explicit common_regex(const std::string & pattern);
|
||||
|
||||
common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const;
|
||||
|
||||
const std::string & str() const { return pattern; }
|
||||
};
|
||||
|
||||
// For testing only (pretty print of failures).
|
||||
std::string regex_to_reversed_partial_regex(const std::string & pattern);
|
||||
+380
-65
@@ -18,6 +18,13 @@
|
||||
#include <map>
|
||||
#include <cinttypes>
|
||||
|
||||
#define SPC_DBG(fmt, ...) LOG_DBG("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_TRC(fmt, ...) LOG_TRC("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_INF(fmt, ...) LOG_INF("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_WRN(fmt, ...) LOG_WRN("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_ERR(fmt, ...) LOG_ERR("spec %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
||||
#define SPC_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__)
|
||||
|
||||
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
|
||||
@@ -26,6 +33,7 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
|
||||
{"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
|
||||
{"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
|
||||
{"draft-mtp", COMMON_SPECULATIVE_TYPE_DRAFT_MTP},
|
||||
{"draft-dflash", COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH},
|
||||
{"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
|
||||
{"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
|
||||
{"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
|
||||
@@ -60,21 +68,20 @@ static bool common_speculative_are_compatible(
|
||||
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
||||
|
||||
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
|
||||
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
|
||||
SPC_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
|
||||
|
||||
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
|
||||
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
|
||||
SPC_DBG("vocab_type dft: %d\n", vocab_type_dft);
|
||||
|
||||
if (vocab_type_tgt != vocab_type_dft) {
|
||||
LOG_WRN("%s: draft model vocab type must match target model to use speculation but "
|
||||
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
|
||||
SPC_WRN("draft model vocab type must match target model to use speculation but "
|
||||
"vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
|
||||
(llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
|
||||
LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
__func__,
|
||||
SPC_WRN("draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
|
||||
llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
|
||||
return false;
|
||||
@@ -82,8 +89,7 @@ static bool common_speculative_are_compatible(
|
||||
|
||||
if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
|
||||
(llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
|
||||
LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
__func__,
|
||||
SPC_WRN("draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
|
||||
llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
|
||||
llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
|
||||
return false;
|
||||
@@ -97,8 +103,8 @@ static bool common_speculative_are_compatible(
|
||||
: n_vocab_dft - n_vocab_tgt;
|
||||
|
||||
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
||||
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
|
||||
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
||||
SPC_DBG("draft model vocab must closely match target model to use speculation but "
|
||||
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
||||
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
||||
return false;
|
||||
}
|
||||
@@ -108,8 +114,8 @@ static bool common_speculative_are_compatible(
|
||||
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
|
||||
|
||||
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
||||
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
|
||||
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
|
||||
SPC_DBG("draft model vocab must match target model to use speculation but "
|
||||
"token %d content differs - target '%s', draft '%s'\n", i,
|
||||
common_token_to_piece(vocab_tgt, i).c_str(),
|
||||
common_token_to_piece(vocab_dft, i).c_str());
|
||||
return false;
|
||||
@@ -186,9 +192,9 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
|
||||
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'draft-simple'\n");
|
||||
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f\n", this->params.n_max, this->params.n_min, this->params.p_min);
|
||||
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
|
||||
this->params.n_gpu_layers,
|
||||
ggml_type_name(this->params.cache_type_k),
|
||||
ggml_type_name(this->params.cache_type_v),
|
||||
@@ -228,16 +234,16 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
}
|
||||
|
||||
const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
|
||||
LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt);
|
||||
SPC_DBG("vocab_cmpt = %d\n", vocab_cmpt);
|
||||
|
||||
if (!vocab_cmpt) {
|
||||
LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__);
|
||||
SPC_ERR("%s", "the target and draft vocabs are not compatible\n");
|
||||
|
||||
throw std::runtime_error("draft model vocab type must match target model to use speculation");
|
||||
}
|
||||
|
||||
if (n_seq != llama_n_seq_max(ctx_dft)) {
|
||||
LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft));
|
||||
SPC_ERR("n_seq mismatch: %d != %d\n", n_seq, llama_n_seq_max(ctx_dft));
|
||||
|
||||
throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq");
|
||||
}
|
||||
@@ -257,7 +263,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
const int ret = llama_decode(ctx_dft, batch);
|
||||
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret);
|
||||
SPC_ERR("failed to decode draft batch, ret = %d\n", ret);
|
||||
|
||||
return false;
|
||||
}
|
||||
@@ -290,7 +296,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
SPC_ERR("llama_decode returned %d\n", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -314,7 +320,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
|
||||
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
@@ -354,7 +360,7 @@ struct common_speculative_impl_draft_simple : public common_speculative_impl {
|
||||
// evaluate the drafted tokens on the draft model
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -449,8 +455,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq)
|
||||
, params(params.draft)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
|
||||
SPC_TRC("%s", "adding speculative implementation 'draft-eagle3'\n");
|
||||
SPC_TRC("- n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
|
||||
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
@@ -493,7 +499,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
|
||||
|
||||
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
|
||||
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
|
||||
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
|
||||
llama_sampler_free(chain);
|
||||
chain = nullptr;
|
||||
}
|
||||
@@ -548,9 +554,9 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
if (pos_max < N - 2) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
|
||||
SPC_WRN("ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
|
||||
"Drafts may degrade.\n",
|
||||
__func__, (int) pos_max, N - 2);
|
||||
(int) pos_max, N - 2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -621,8 +627,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
};
|
||||
const int32_t rc = llama_encode(ctx_dft, enc_batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
|
||||
__func__, rc, (int) n_chunk, (int) i);
|
||||
SPC_ERR("llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
|
||||
rc, (int) n_chunk, (int) i);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -692,8 +698,8 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
if (batch.n_tokens > 0) {
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
|
||||
__func__, rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
|
||||
SPC_ERR("llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
|
||||
rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -744,7 +750,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
SPC_ERR("llama_decode returned %d\n", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -770,7 +776,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
|
||||
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
@@ -809,7 +815,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
|
||||
ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -893,6 +899,305 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
|
||||
}
|
||||
};
|
||||
|
||||
// DFlash: block-diffusion drafting with a draft-side KV cache injection
|
||||
struct common_speculative_impl_draft_dflash : public common_speculative_impl {
|
||||
common_params_speculative_draft params;
|
||||
|
||||
llama_batch batch; // noise tokens
|
||||
llama_batch batch_inject; // target features for KV cache injection
|
||||
|
||||
std::vector<common_sampler_ptr> smpls;
|
||||
|
||||
int32_t n_embd_dec = 0; // draft hidden size
|
||||
int32_t n_embd_enc = 0; // target_layer_ids_n * target_hidden_size
|
||||
int32_t n_embd_tgt = 0; // target model hidden size
|
||||
|
||||
int32_t block_size = 0;
|
||||
llama_token mask_token_id = 0;
|
||||
|
||||
const int32_t * target_layer_ids = nullptr; // model_dft's extract layer indices
|
||||
uint32_t target_layer_ids_n = 0;
|
||||
|
||||
// scratch buffer for concatenated target features [n_tokens, n_embd_enc]
|
||||
std::vector<float> features_buf;
|
||||
|
||||
common_speculative_impl_draft_dflash(const common_params_speculative & params, uint32_t n_seq)
|
||||
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, n_seq)
|
||||
, params(params.draft)
|
||||
{
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
GGML_ASSERT(ctx_tgt && ctx_dft && "DFlash requires ctx_tgt and ctx_dft to be set");
|
||||
|
||||
const llama_model * model_dft = llama_get_model(ctx_dft);
|
||||
const llama_model * model_tgt = llama_get_model(ctx_tgt);
|
||||
|
||||
target_layer_ids = llama_model_target_layer_ids (model_dft);
|
||||
target_layer_ids_n = llama_model_target_layer_ids_n(model_dft);
|
||||
GGML_ASSERT(target_layer_ids_n > 0 && "DFlash model has no target_layer_ids");
|
||||
|
||||
n_embd_tgt = llama_model_n_embd(model_tgt);
|
||||
n_embd_dec = llama_model_n_embd(model_dft);
|
||||
n_embd_enc = (int32_t) target_layer_ids_n * n_embd_tgt;
|
||||
|
||||
// read the trained block size from the dflash.block_size metadata key
|
||||
block_size = 16;
|
||||
{
|
||||
char buf[32] = {};
|
||||
if (llama_model_meta_val_str(model_dft, "dflash.block_size", buf, sizeof(buf)) >= 0) {
|
||||
block_size = std::atoi(buf);
|
||||
}
|
||||
}
|
||||
mask_token_id = llama_vocab_mask(llama_model_get_vocab(model_dft));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-dflash'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
|
||||
LOG_INF("%s: - block_size=%d, mask_token_id=%d, n_extract=%u\n", __func__, block_size, mask_token_id, target_layer_ids_n);
|
||||
|
||||
// DFlash input is [id_last, <mask> * (block_size-1)], so it can draft at most block_size-1 tokens per step
|
||||
if (this->params.n_max > block_size - 1 || this->params.n_min > block_size - 1) {
|
||||
LOG_WRN("%s: requested draft size (n_max=%d, n_min=%d) exceeds the trained DFlash block size %d -- clamping to %d\n",
|
||||
__func__, this->params.n_max, this->params.n_min, block_size, block_size - 1);
|
||||
this->params.n_max = std::min(this->params.n_max, block_size - 1);
|
||||
this->params.n_min = std::min(this->params.n_min, block_size - 1);
|
||||
}
|
||||
|
||||
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, n_seq);
|
||||
batch_inject = llama_batch_init(llama_n_batch(ctx_dft), n_embd_dec, n_seq);
|
||||
|
||||
smpls.resize(n_seq);
|
||||
for (auto & s : smpls) {
|
||||
common_params_sampling sparams;
|
||||
sparams.no_perf = false;
|
||||
sparams.top_k = 10;
|
||||
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
|
||||
s.reset(common_sampler_init(model_dft, sparams));
|
||||
}
|
||||
|
||||
// turn on extraction of the target layers' input embeddings
|
||||
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
|
||||
llama_set_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k], true);
|
||||
}
|
||||
|
||||
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
|
||||
llama_set_causal_attn(ctx_dft, false); // DFlash needs non-causal attention
|
||||
}
|
||||
|
||||
~common_speculative_impl_draft_dflash() override {
|
||||
llama_batch_free(batch);
|
||||
llama_batch_free(batch_inject);
|
||||
}
|
||||
|
||||
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
|
||||
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int32_t N = (int32_t) prompt.size();
|
||||
if (N <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(params.ctx_dft), seq_id);
|
||||
if (pos_max < N - 1) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - process() did not run on every prefill ubatch. "
|
||||
"Drafts may degrade.\n",
|
||||
__func__, (int) pos_max, N - 1);
|
||||
}
|
||||
}
|
||||
|
||||
bool process(const llama_batch & batch_in) override {
|
||||
if (batch_in.n_tokens <= 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const int32_t n_tokens = batch_in.n_tokens;
|
||||
|
||||
// per-seq inclusive batch range (assumes each seq's tokens are contiguous in the batch)
|
||||
std::vector<int32_t> i_batch_beg(n_seq, -1);
|
||||
std::vector<int32_t> i_batch_end(n_seq, -1);
|
||||
for (int32_t k = 0; k < n_tokens; ++k) {
|
||||
GGML_ASSERT(batch_in.n_seq_id[k] == 1);
|
||||
const llama_seq_id seq_id = batch_in.seq_id[k][0];
|
||||
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
|
||||
continue;
|
||||
}
|
||||
i_batch_end[seq_id] = k;
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
i_batch_beg[seq_id] = k;
|
||||
}
|
||||
}
|
||||
|
||||
auto * ctx_tgt = this->params.ctx_tgt;
|
||||
auto * ctx_dft = this->params.ctx_dft;
|
||||
|
||||
const int32_t n_ubatch = (int32_t) llama_n_ubatch(ctx_dft);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_batch_beg[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
const int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1;
|
||||
|
||||
for (int32_t offset = 0; offset < n_rows; offset += n_ubatch) {
|
||||
const int32_t n_chunk = std::min(n_ubatch, n_rows - offset);
|
||||
|
||||
// gather this chunk's target features, interleaved by extract layer
|
||||
features_buf.resize((size_t) n_chunk * n_embd_enc);
|
||||
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
|
||||
const float * layer = llama_get_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k]);
|
||||
if (!layer) {
|
||||
GGML_ABORT("DFlash: target layer %d input not extracted.", target_layer_ids[k]);
|
||||
}
|
||||
for (int32_t i = 0; i < n_chunk; ++i) {
|
||||
float * dst = features_buf.data() + (size_t) i * n_embd_enc + k * (size_t) n_embd_tgt;
|
||||
const float * src = layer + (size_t) (i_batch_beg[seq_id] + offset + i) * n_embd_tgt;
|
||||
std::memcpy(dst, src, (size_t) n_embd_tgt * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// fuse extracted features through DFlash encoder
|
||||
llama_batch enc_batch = {
|
||||
/*.n_tokens =*/ n_chunk,
|
||||
/*.token =*/ nullptr,
|
||||
/*.embd =*/ features_buf.data(),
|
||||
/*.pos =*/ nullptr,
|
||||
/*.n_seq_id =*/ nullptr,
|
||||
/*.seq_id =*/ nullptr,
|
||||
/*.logits =*/ nullptr,
|
||||
};
|
||||
|
||||
int32_t rc = llama_encode(ctx_dft, enc_batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
|
||||
__func__, rc, (int) n_chunk, (int) offset);
|
||||
return false;
|
||||
}
|
||||
|
||||
const float * inp_g = llama_get_embeddings_nextn(ctx_dft);
|
||||
GGML_ASSERT(inp_g && "DFlash encoder produced no output.");
|
||||
|
||||
// inject the DFlash decoder K/V cache at the tokens' target positions
|
||||
batch_inject.n_tokens = n_chunk;
|
||||
std::memcpy(batch_inject.embd, inp_g, (size_t) n_chunk * n_embd_dec * sizeof(float));
|
||||
|
||||
for (int32_t i = 0; i < n_chunk; ++i) {
|
||||
batch_inject.pos[i] = batch_in.pos[i_batch_beg[seq_id] + offset + i];
|
||||
batch_inject.n_seq_id[i] = 1;
|
||||
batch_inject.seq_id[i][0] = seq_id;
|
||||
batch_inject.logits[i] = false;
|
||||
}
|
||||
rc = llama_decode(ctx_dft, batch_inject);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
|
||||
__func__, rc, (int) n_chunk, (int) offset);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void draft(common_speculative_draft_params_vec & dparams) override {
|
||||
auto & ctx_dft = params.ctx_dft;
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
||||
// build one batch holding every drafting sequence's noise block into a single decode)
|
||||
// record where each block starts and its size
|
||||
std::vector<int32_t> i_block_beg(n_seq, -1);
|
||||
std::vector<int32_t> n_block (n_seq, 0);
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
auto & dp = dparams[seq_id];
|
||||
if (!dp.drafting) {
|
||||
continue;
|
||||
}
|
||||
|
||||
common_sampler_reset(smpls[seq_id].get());
|
||||
|
||||
const int32_t n = (int32_t) dp.n_past;
|
||||
|
||||
int32_t n_draft = params.n_max;
|
||||
if (dp.n_max > 0) {
|
||||
n_draft = std::min(n_draft, dp.n_max);
|
||||
}
|
||||
|
||||
const int32_t n_block_tokens = n_draft + 1; // id_last + n_draft * <mask>
|
||||
i_block_beg[seq_id] = batch.n_tokens;
|
||||
n_block [seq_id] = n_block_tokens;
|
||||
for (int32_t i = 0; i < n_block_tokens; ++i) {
|
||||
common_batch_add(batch, i == 0 ? dp.id_last : mask_token_id, n + i, { seq_id }, true);
|
||||
}
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// decode all sequence's noise block in a single batch
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
|
||||
return;
|
||||
}
|
||||
|
||||
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
|
||||
if (i_block_beg[seq_id] < 0) {
|
||||
continue;
|
||||
}
|
||||
auto & dp = dparams[seq_id];
|
||||
|
||||
const int32_t beg = i_block_beg[seq_id];
|
||||
const int32_t n_block_tokens = n_block[seq_id];
|
||||
|
||||
auto * smpl = smpls[seq_id].get();
|
||||
|
||||
auto & result = *dp.result;
|
||||
|
||||
// greedily read the predicted block at this sequence's noise positions 1..n_block_tokens-1
|
||||
for (int32_t i = 1; i < n_block_tokens; ++i) {
|
||||
common_sampler_sample(smpl, ctx_dft, beg + i, true);
|
||||
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
seq_id, k, i - 1, cur_p->data[k].id, cur_p->data[k].p,
|
||||
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
|
||||
const llama_token id = cur_p->data[0].id;
|
||||
|
||||
if (cur_p->data[0].p < params.p_min) {
|
||||
break;
|
||||
}
|
||||
|
||||
common_sampler_accept(smpl, id, true);
|
||||
|
||||
result.push_back(id);
|
||||
}
|
||||
|
||||
if (result.size() < (size_t) params.n_min) {
|
||||
result.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
|
||||
// noop
|
||||
}
|
||||
|
||||
bool need_embd() const override {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
|
||||
|
||||
@@ -942,9 +1247,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
"MTP input row width must match the target h_nextn width");
|
||||
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
|
||||
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'draft-mtp'\n");
|
||||
SPC_TRC("- n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
|
||||
SPC_TRC("- gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n",
|
||||
this->params.n_gpu_layers,
|
||||
ggml_type_name(this->params.cache_type_k),
|
||||
ggml_type_name(this->params.cache_type_v),
|
||||
@@ -975,7 +1280,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
|
||||
|
||||
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
|
||||
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
|
||||
SPC_WRN("backend offload failed for seq_id=%d; using CPU sampler\n", (int) seq_id);
|
||||
llama_sampler_free(chain);
|
||||
chain = nullptr;
|
||||
}
|
||||
@@ -1038,11 +1343,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
|
||||
|
||||
if (pos_max < N - 1 && !is_mem_shared) {
|
||||
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
|
||||
SPC_WRN("ctx_dft pos_max=%d < N-1=%d - "
|
||||
"process() hook may not have run on every prefill ubatch "
|
||||
"(need_embd / logits=1 on every prompt position?). "
|
||||
"Drafts may degrade.\n",
|
||||
__func__, (int) pos_max, N - 1);
|
||||
(int) pos_max, N - 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1128,8 +1433,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
const int32_t rc = llama_decode(ctx_dft, batch);
|
||||
if (rc != 0) {
|
||||
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
|
||||
__func__, head, (int) rc, (int) batch_in.pos[0]);
|
||||
SPC_ERR("llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
|
||||
head, (int) rc, (int) batch_in.pos[0]);
|
||||
ok = false;
|
||||
break;
|
||||
}
|
||||
@@ -1217,7 +1522,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
|
||||
int ret = llama_decode(ctx_dft, batch);
|
||||
if (ret != 0) {
|
||||
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
|
||||
SPC_ERR("llama_decode[%d] returned %d\n", i, ret);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -1239,7 +1544,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
|
||||
const auto * cur_p = common_sampler_get_candidates(smpl, true);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
SPC_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
|
||||
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
@@ -1353,8 +1658,8 @@ struct common_speculative_impl_ngram_simple : public common_speculative_impl {
|
||||
, params(params.ngram_simple)
|
||||
, config(config)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-simple'\n", __func__);
|
||||
LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'ngram-simple'\n");
|
||||
SPC_TRC("- size_n=%d, size_m=%d, min_hits=%d\n",
|
||||
this->params.size_n, this->params.size_m, this->params.min_hits);
|
||||
}
|
||||
|
||||
@@ -1403,8 +1708,8 @@ struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
|
||||
this->config.push_back(config);
|
||||
}
|
||||
|
||||
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str());
|
||||
LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__,
|
||||
SPC_TRC("adding speculative implementation '%s'\n", common_speculative_type_to_str(this->type).c_str());
|
||||
SPC_TRC("- size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n",
|
||||
config.size_key, config.size_value, config.key_only, config.min_hits);
|
||||
}
|
||||
|
||||
@@ -1478,15 +1783,15 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
, verbose(std::getenv("LLAMA_TRACE") != nullptr) {
|
||||
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
|
||||
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__);
|
||||
LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'ngram-mod'\n");
|
||||
SPC_TRC("- n_match=%d, n_max=%d, n_min=%d\n",
|
||||
this->params.n_match, this->params.n_max, this->params.n_min);
|
||||
LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__,
|
||||
SPC_TRC("- mod size=%zu (%.3f MB)\n",
|
||||
mod.size(), (float)(mod.size_bytes())/1024/1024);
|
||||
|
||||
if (this->params.n_match < 16) {
|
||||
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
|
||||
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match);
|
||||
SPC_WRN("ngram_mod n_match=%d is too small - poor quality is possible, "
|
||||
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", this->params.n_match);
|
||||
}
|
||||
|
||||
sinfos.resize(n_seq);
|
||||
@@ -1510,11 +1815,11 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
sinfo.i_last = prompt.size() - n;
|
||||
|
||||
const double f = (double)mod.get_used() / (double)mod.size();
|
||||
LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f);
|
||||
SPC_TRC("ngram_mod occupancy = %zu/%zu (%.2f)\n", mod.get_used(), mod.size(), f);
|
||||
|
||||
constexpr double f_thold = 0.25;
|
||||
if (f > f_thold) {
|
||||
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
|
||||
SPC_WRN("ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", f, f_thold);
|
||||
|
||||
mod.reset();
|
||||
}
|
||||
@@ -1608,7 +1913,7 @@ struct common_speculative_impl_ngram_mod : public common_speculative_impl {
|
||||
sinfo.n_low++;
|
||||
if (sinfo.n_low >= 5) {
|
||||
if (verbose) {
|
||||
LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low);
|
||||
SPC_TRC("low acceptance streak (%d) - resetting ngram_mod\n", sinfo.n_low);
|
||||
}
|
||||
|
||||
mod.reset();
|
||||
@@ -1658,8 +1963,8 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
, save_dynamic(save_dynamic)
|
||||
, save_static(save_static)
|
||||
{
|
||||
LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__);
|
||||
LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__,
|
||||
SPC_TRC("%s", "adding speculative implementation 'ngram-cache'\n");
|
||||
SPC_TRC("- n_draft=%d, cache_static=%s, cache_dynamic=%s\n",
|
||||
n_draft,
|
||||
path_static.empty() ? "none" : path_static.c_str(),
|
||||
path_dynamic.empty() ? "none" : path_dynamic.c_str());
|
||||
@@ -1674,7 +1979,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
sinfo.ngram_cache_static = ngram_cache_static;
|
||||
}
|
||||
} catch (...) {
|
||||
LOG_ERR("failed to open static lookup cache: %s", path_static.c_str());
|
||||
SPC_ERR("failed to open static lookup cache: %s", path_static.c_str());
|
||||
GGML_ABORT("Couldn't read static lookup cache");
|
||||
}
|
||||
}
|
||||
@@ -1687,7 +1992,7 @@ struct common_speculative_impl_ngram_cache : public common_speculative_impl {
|
||||
sinfo.ngram_cache_dynamic = ngram_cache_dynamic;
|
||||
}
|
||||
} catch (...) {
|
||||
LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
|
||||
SPC_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
|
||||
GGML_ABORT("Couldn't read dynamic lookup cache");
|
||||
}
|
||||
}
|
||||
@@ -1836,6 +2141,7 @@ std::string common_speculative_type_to_str(common_speculative_type type) {
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp";
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH: return "draft-dflash";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
|
||||
@@ -1888,6 +2194,7 @@ int32_t common_speculative_n_max(const common_params_speculative * spec) {
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH:
|
||||
n_max = std::max(n_max, std::max(0, spec->draft.n_max));
|
||||
break;
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
|
||||
@@ -1925,6 +2232,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
|
||||
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
|
||||
bool has_draft_dflash = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH)) && params.draft.ctx_dft != nullptr;
|
||||
|
||||
|
||||
|
||||
@@ -1935,7 +2243,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD));
|
||||
|
||||
// when adding a new type - update here the logic above
|
||||
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9);
|
||||
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 10);
|
||||
|
||||
// this list here defines the priority of the speculators
|
||||
// the one with highest priority are listed first
|
||||
@@ -1965,6 +2273,9 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
if (has_draft_mtp) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
|
||||
}
|
||||
if (has_draft_dflash) {
|
||||
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, params));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
|
||||
@@ -1985,6 +2296,10 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
impls.push_back(std::make_unique<common_speculative_impl_draft_mtp>(config.params, n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH: {
|
||||
impls.push_back(std::make_unique<common_speculative_impl_draft_dflash>(config.params, n_seq));
|
||||
break;
|
||||
}
|
||||
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
|
||||
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
|
||||
|
||||
@@ -2034,7 +2349,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
|
||||
}
|
||||
|
||||
if (impls.empty()) {
|
||||
LOG_WRN("%s: no implementations specified for speculative decoding\n", __func__);
|
||||
SPC_TRC("%s", "no implementations specified for speculative decoding\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -2161,13 +2476,13 @@ void common_speculative_draft(common_speculative * spec) {
|
||||
|
||||
if (dp.n_max > 0) {
|
||||
if (!result.empty() && (int) result.size() > dp.n_max) {
|
||||
LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max);
|
||||
SPC_DBG("truncating draft to %d tokens\n", dp.n_max);
|
||||
result.resize(dp.n_max);
|
||||
}
|
||||
}
|
||||
|
||||
if (!result.empty()) {
|
||||
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
|
||||
SPC_DBG("called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n",
|
||||
common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(),
|
||||
impl.get()->n_call_draft, result.size());
|
||||
|
||||
@@ -2291,7 +2606,7 @@ void common_speculative_print_stats(const common_speculative * spec) {
|
||||
str_stats = ", #mean acc len = " + oss.str() + ", #acc rate/pos = (" + tmp.str() + ")";
|
||||
}
|
||||
|
||||
LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
|
||||
SPC_TRC("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
|
||||
common_speculative_type_to_str(impl->type).c_str(),
|
||||
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
|
||||
impl->n_gen_drafts,
|
||||
|
||||
@@ -50,6 +50,8 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"DeepseekV2ForCausalLM": "deepseek",
|
||||
"DeepseekV3ForCausalLM": "deepseek",
|
||||
"DeepseekV32ForCausalLM": "deepseek",
|
||||
"DFlashDraftModel": "qwen",
|
||||
"DeepseekV4ForCausalLM": "deepseek",
|
||||
"DistilBertForMaskedLM": "bert",
|
||||
"DistilBertForSequenceClassification": "bert",
|
||||
"DistilBertModel": "bert",
|
||||
|
||||
+14
-1
@@ -1273,7 +1273,7 @@ class TextModel(ModelBase):
|
||||
if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
|
||||
self.gguf_writer.add_layer_norm_eps(f_norm_eps)
|
||||
logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
|
||||
if (n_experts := self.find_hparam(["num_local_experts", "num_experts"], optional=True)) is not None:
|
||||
if (n_experts := self.find_hparam(["num_local_experts", "num_experts", "n_routed_experts"], optional=True)) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
logger.info(f"gguf: expert count = {n_experts}")
|
||||
if (n_experts_used := self.find_hparam(["num_experts_per_tok", "num_experts_per_token", "top_k_experts"], optional=True)) is not None:
|
||||
@@ -1291,6 +1291,8 @@ class TextModel(ModelBase):
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
|
||||
elif score_func == "softmax":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
|
||||
elif score_func == "sqrtsoftplus":
|
||||
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SQRTSOFTPLUS)
|
||||
else:
|
||||
raise ValueError(f"Unsupported expert score gating function value: {score_func}")
|
||||
logger.info(f"gguf: expert score gating function = {score_func}")
|
||||
@@ -2600,6 +2602,17 @@ class LazyTorchTensor(gguf.LazyBase):
|
||||
return cls._wrap_fn(func)(*args, **kwargs)
|
||||
|
||||
|
||||
if hasattr(torch, "float8_e8m0fnu"):
|
||||
_torch_float8_e8m0 = torch.float8_e8m0fnu
|
||||
LazyTorchTensor._dtype_map[_torch_float8_e8m0] = np.uint8
|
||||
LazyTorchTensor._dtype_byteswap_map[_torch_float8_e8m0] = np.uint8
|
||||
LazyTorchTensor._dtype_str_map["F8_E8M0"] = _torch_float8_e8m0
|
||||
else:
|
||||
# Older torch builds do not expose F8_E8M0. Keep the raw bytes so callers
|
||||
# that know the format can decode them explicitly.
|
||||
LazyTorchTensor._dtype_str_map["F8_E8M0"] = torch.uint8
|
||||
|
||||
|
||||
def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str:
|
||||
# TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
|
||||
# maybe we should fallback to text model's arch in that case, since not many models have both
|
||||
|
||||
+308
-1
@@ -1,15 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Any, Callable, Iterable, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch import Tensor
|
||||
|
||||
from .base import MmprojModel, ModelBase, TextModel, gguf, logger
|
||||
from .base import LazyTorchTensor, MmprojModel, ModelBase, TextModel, gguf, logger
|
||||
|
||||
from .qwen import QwenModel
|
||||
|
||||
@@ -467,3 +470,307 @@ class DeepseekV32Model(DeepseekV2Model):
|
||||
self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"])
|
||||
self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"])
|
||||
self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"])
|
||||
|
||||
|
||||
@ModelBase.register("DeepseekV4ForCausalLM")
|
||||
class DeepseekV4Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.DEEPSEEK4
|
||||
_skipped_mtp_tensors = 0
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
type(self)._skipped_mtp_tensors = 0
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
with open(self.dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||
raw_hparams = json.load(f)
|
||||
for key, value in raw_hparams.items():
|
||||
self.hparams.setdefault(key, value)
|
||||
|
||||
self.block_count = self.hparams["num_hidden_layers"]
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
self._dsv4_fp8_dequantized: set[str] = set()
|
||||
self._dsv4_bf16_tensors: set[str] = set()
|
||||
self._dsv4_f32_tensors: set[str] = set()
|
||||
self._dsv4_mxfp4_generated = False
|
||||
self._collect_source_dtypes()
|
||||
|
||||
if type(self)._skipped_mtp_tensors:
|
||||
logger.info("Skipping %d DeepSeek-V4 MTP tensor(s) for conversion v0", type(self)._skipped_mtp_tensors)
|
||||
|
||||
# add a default chat template; if the model has a built-in template, it will be overridden later
|
||||
template_path = Path(__file__).parent.parent / "models" / "templates" / "deepseek-ai-DeepSeek-V4.jinja"
|
||||
if template_path.is_file():
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
self.gguf_writer.add_chat_template(f.read())
|
||||
|
||||
@classmethod
|
||||
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
|
||||
name, _ = item
|
||||
if name.startswith("mtp."):
|
||||
cls._skipped_mtp_tensors += 1
|
||||
return None
|
||||
return super().filter_tensors(item)
|
||||
|
||||
@staticmethod
|
||||
def _float8_dtypes() -> tuple[torch.dtype, ...]:
|
||||
return tuple(
|
||||
dtype for dtype in (
|
||||
getattr(torch, "float8_e4m3fn", None),
|
||||
getattr(torch, "float8_e5m2", None),
|
||||
) if dtype is not None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _e8m0_to_float(scale: Tensor) -> Tensor:
|
||||
torch_float8_e8m0 = getattr(torch, "float8_e8m0fnu", None)
|
||||
if torch_float8_e8m0 is not None and scale.dtype == torch_float8_e8m0:
|
||||
return scale.float()
|
||||
|
||||
bits = scale.view(torch.uint8).float()
|
||||
return torch.exp2(bits - 127.0)
|
||||
|
||||
def _collect_source_dtypes(self) -> None:
|
||||
for name, gen in self.model_tensors.items():
|
||||
dtype = gen().dtype
|
||||
if dtype == torch.bfloat16:
|
||||
self._dsv4_bf16_tensors.add(name)
|
||||
elif dtype == torch.float32:
|
||||
self._dsv4_f32_tensors.add(name)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
|
||||
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
|
||||
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
|
||||
self.gguf_writer.add_swiglu_clamp_exp([hparams["swiglu_limit"]] * self.block_count)
|
||||
self.gguf_writer.add_swiglu_clamp_shexp([hparams["swiglu_limit"]] * self.block_count)
|
||||
|
||||
self.gguf_writer.add_indexer_head_count(hparams["index_n_heads"])
|
||||
self.gguf_writer.add_indexer_key_length(hparams["index_head_dim"])
|
||||
self.gguf_writer.add_indexer_top_k(hparams["index_topk"])
|
||||
|
||||
self.gguf_writer.add_attention_output_group_count(hparams["o_groups"])
|
||||
self.gguf_writer.add_attention_output_lora_rank(hparams["o_lora_rank"])
|
||||
self.gguf_writer.add_attention_compress_ratios(hparams["compress_ratios"])
|
||||
self.gguf_writer.add_attention_compress_rope_freq_base(hparams["compress_rope_theta"])
|
||||
self.gguf_writer.add_hyper_connection_count(hparams["hc_mult"])
|
||||
self.gguf_writer.add_hyper_connection_sinkhorn_iterations(hparams["hc_sinkhorn_iters"])
|
||||
self.gguf_writer.add_hyper_connection_epsilon(hparams["hc_eps"])
|
||||
self.gguf_writer.add_hash_layer_count(hparams["num_hash_layers"])
|
||||
|
||||
def dequant_model(self):
|
||||
fp8_dtypes = self._float8_dtypes()
|
||||
tensors_to_remove: list[str] = []
|
||||
|
||||
def dequant_fp8_weight(weight: Tensor, scale: Tensor) -> Tensor:
|
||||
out_features, in_features = weight.shape
|
||||
scale_f = self._e8m0_to_float(scale)
|
||||
scale_f = scale_f.repeat_interleave(128, 0)[:out_features]
|
||||
scale_f = scale_f.repeat_interleave(128, 1)[:, :in_features]
|
||||
return weight.float() * scale_f
|
||||
|
||||
for name in list(self.model_tensors.keys()):
|
||||
if not name.endswith(".scale"):
|
||||
continue
|
||||
weight_name = name.removesuffix(".scale") + ".weight"
|
||||
if weight_name not in self.model_tensors:
|
||||
continue
|
||||
|
||||
weight = self.model_tensors[weight_name]
|
||||
scale = self.model_tensors[name]
|
||||
if weight().dtype not in fp8_dtypes:
|
||||
continue
|
||||
|
||||
self.model_tensors[weight_name] = lambda w=weight, s=scale: dequant_fp8_weight(w(), s())
|
||||
self._dsv4_fp8_dequantized.add(weight_name)
|
||||
tensors_to_remove.append(name)
|
||||
|
||||
for name in tensors_to_remove:
|
||||
del self.model_tensors[name]
|
||||
|
||||
@staticmethod
|
||||
def _pack_mxfp4_blocks(weight: Tensor, scale: Tensor) -> np.ndarray:
|
||||
packed = weight.contiguous().view(torch.uint8)
|
||||
scale_u8 = scale.contiguous().view(torch.uint8)
|
||||
|
||||
out_features, packed_cols = packed.shape
|
||||
logical_cols = packed_cols * 2
|
||||
if logical_cols % 32 != 0:
|
||||
raise ValueError(f"MXFP4 source row has {logical_cols} values, expected a multiple of 32")
|
||||
|
||||
n_blocks = logical_cols // 32
|
||||
if tuple(scale_u8.shape) != (out_features, n_blocks):
|
||||
raise ValueError(f"MXFP4 scale shape {tuple(scale_u8.shape)} does not match {(out_features, n_blocks)}")
|
||||
|
||||
src = packed.reshape(out_features, n_blocks, 16)
|
||||
low = src & 0x0F
|
||||
high = (src >> 4) & 0x0F
|
||||
|
||||
# The safetensors bytes store adjacent values as low/high nibbles.
|
||||
# ggml MXFP4 blocks store values 0..15 in low nibbles and 16..31 in high nibbles.
|
||||
vals = torch.stack((low, high), dim=-1).reshape(out_features, n_blocks, 32)
|
||||
qs = vals[:, :, :16] | (vals[:, :, 16:] << 4)
|
||||
raw = torch.cat((scale_u8.unsqueeze(-1), qs.to(torch.uint8)), dim=-1)
|
||||
return raw.reshape(out_features, n_blocks * 17).cpu().numpy()
|
||||
|
||||
def _write_mxfp4_expert_tensor(self, bid: int, proj: str, tensor_key: gguf.MODEL_TENSOR) -> list[str]:
|
||||
n_experts = self.hparams["n_routed_experts"]
|
||||
data: np.ndarray | None = None
|
||||
consumed: list[str] = []
|
||||
|
||||
for eid in range(n_experts):
|
||||
weight_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.weight"
|
||||
scale_name = f"layers.{bid}.ffn.experts.{eid}.{proj}.scale"
|
||||
if weight_name not in self.model_tensors or scale_name not in self.model_tensors:
|
||||
raise KeyError(f"Missing routed expert tensors for {weight_name}")
|
||||
|
||||
weight = LazyTorchTensor.to_eager(self.model_tensors[weight_name]())
|
||||
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
|
||||
packed = self._pack_mxfp4_blocks(weight, scale)
|
||||
if data is None:
|
||||
data = np.empty((n_experts, *packed.shape), dtype=packed.dtype)
|
||||
data[eid] = packed
|
||||
consumed.extend((weight_name, scale_name))
|
||||
|
||||
assert data is not None
|
||||
new_name = self.format_tensor_name(tensor_key, bid)
|
||||
shape = gguf.quant_shape_from_byte_shape(data.shape, gguf.GGMLQuantizationType.MXFP4)
|
||||
logger.info(f"{new_name}: repacked routed experts to MXFP4, shape = {{{', '.join(str(n) for n in reversed(shape))}}}")
|
||||
self.gguf_writer.add_tensor(new_name, data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
|
||||
|
||||
return consumed
|
||||
|
||||
def _write_hash_routing_tensors(self) -> list[str]:
|
||||
consumed: list[str] = []
|
||||
|
||||
for bid in range(self.hparams["num_hash_layers"]):
|
||||
name = f"layers.{bid}.ffn.gate.tid2eid"
|
||||
if name not in self.model_tensors:
|
||||
raise KeyError(f"Missing hash routing tensor {name}")
|
||||
|
||||
data_torch = LazyTorchTensor.to_eager(self.model_tensors[name]())
|
||||
data = data_torch.to(torch.int32).cpu().numpy()
|
||||
new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_TID2EID, bid, ".weight")
|
||||
logger.info(f"{new_name}: converted hash routing table to I32, shape = {{{', '.join(str(n) for n in reversed(data.shape))}}}")
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
consumed.append(name)
|
||||
|
||||
return consumed
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
if self._dsv4_mxfp4_generated:
|
||||
return ()
|
||||
|
||||
consumed: list[str] = self._write_hash_routing_tensors()
|
||||
for bid in range(self.block_count):
|
||||
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w1", gguf.MODEL_TENSOR.FFN_GATE_EXP))
|
||||
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP))
|
||||
consumed.extend(self._write_mxfp4_expert_tensor(bid, "w3", gguf.MODEL_TENSOR.FFN_UP_EXP))
|
||||
|
||||
for name in consumed:
|
||||
del self.model_tensors[name]
|
||||
|
||||
self._dsv4_mxfp4_generated = True
|
||||
return ()
|
||||
|
||||
def _format_dsv4_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> str:
|
||||
return self.format_tensor_name(key, bid, suffix)
|
||||
|
||||
def _map_dsv4_tensor_name(self, name: str, bid: int | None) -> tuple[gguf.MODEL_TENSOR, str]:
|
||||
root_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
|
||||
"embed.weight": (gguf.MODEL_TENSOR.TOKEN_EMBD, ".weight"),
|
||||
"norm.weight": (gguf.MODEL_TENSOR.OUTPUT_NORM, ".weight"),
|
||||
"head.weight": (gguf.MODEL_TENSOR.OUTPUT, ".weight"),
|
||||
"hc_head_fn": (gguf.MODEL_TENSOR.HC_HEAD_FN, ".weight"),
|
||||
"hc_head_base": (gguf.MODEL_TENSOR.HC_HEAD_BASE, ".weight"),
|
||||
"hc_head_scale": (gguf.MODEL_TENSOR.HC_HEAD_SCALE, ".weight"),
|
||||
}
|
||||
if name in root_map:
|
||||
return root_map[name]
|
||||
|
||||
match = re.match(r"layers\.(\d+)\.(.+)$", name)
|
||||
if match is None:
|
||||
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
|
||||
|
||||
layer = int(match.group(1))
|
||||
if bid != layer:
|
||||
raise ValueError(f"Tensor {name!r} parsed bid {bid} but layer name has {layer}")
|
||||
|
||||
layer_map: dict[str, tuple[gguf.MODEL_TENSOR, str]] = {
|
||||
"hc_attn_fn": (gguf.MODEL_TENSOR.HC_ATTN_FN, ".weight"),
|
||||
"hc_attn_base": (gguf.MODEL_TENSOR.HC_ATTN_BASE, ".weight"),
|
||||
"hc_attn_scale": (gguf.MODEL_TENSOR.HC_ATTN_SCALE, ".weight"),
|
||||
"hc_ffn_fn": (gguf.MODEL_TENSOR.HC_FFN_FN, ".weight"),
|
||||
"hc_ffn_base": (gguf.MODEL_TENSOR.HC_FFN_BASE, ".weight"),
|
||||
"hc_ffn_scale": (gguf.MODEL_TENSOR.HC_FFN_SCALE, ".weight"),
|
||||
"attn.attn_sink": (gguf.MODEL_TENSOR.ATTN_SINKS, ".weight"),
|
||||
"attn.wq_a.weight": (gguf.MODEL_TENSOR.ATTN_Q_A, ".weight"),
|
||||
"attn.wq_b.weight": (gguf.MODEL_TENSOR.ATTN_Q_B, ".weight"),
|
||||
"attn.q_norm.weight": (gguf.MODEL_TENSOR.ATTN_Q_A_NORM, ".weight"),
|
||||
"attn.wkv.weight": (gguf.MODEL_TENSOR.ATTN_KV, ".weight"),
|
||||
"attn.kv_norm.weight": (gguf.MODEL_TENSOR.ATTN_KV_NORM, ".weight"),
|
||||
"attn.wo_a.weight": (gguf.MODEL_TENSOR.ATTN_OUT_A, ".weight"),
|
||||
"attn.wo_b.weight": (gguf.MODEL_TENSOR.ATTN_OUT_B, ".weight"),
|
||||
"attn.compressor.ape": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_APE, ".weight"),
|
||||
"attn.compressor.wkv.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WKV, ".weight"),
|
||||
"attn.compressor.wgate.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_WGATE, ".weight"),
|
||||
"attn.compressor.norm.weight": (gguf.MODEL_TENSOR.ATTN_COMPRESSOR_NORM, ".weight"),
|
||||
"attn.indexer.wq_b.weight": (gguf.MODEL_TENSOR.INDEXER_ATTN_Q_B, ".weight"),
|
||||
"attn.indexer.weights_proj.weight": (gguf.MODEL_TENSOR.INDEXER_PROJ, ".weight"),
|
||||
"attn.indexer.compressor.ape": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_APE, ".weight"),
|
||||
"attn.indexer.compressor.wkv.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WKV, ".weight"),
|
||||
"attn.indexer.compressor.wgate.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_WGATE, ".weight"),
|
||||
"attn.indexer.compressor.norm.weight": (gguf.MODEL_TENSOR.INDEXER_COMPRESSOR_NORM, ".weight"),
|
||||
"attn_norm.weight": (gguf.MODEL_TENSOR.ATTN_NORM, ".weight"),
|
||||
"ffn_norm.weight": (gguf.MODEL_TENSOR.FFN_NORM, ".weight"),
|
||||
"ffn.gate.weight": (gguf.MODEL_TENSOR.FFN_GATE_INP, ".weight"),
|
||||
"ffn.gate.bias": (gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, ".bias"),
|
||||
"ffn.gate.tid2eid": (gguf.MODEL_TENSOR.FFN_GATE_TID2EID, ".weight"),
|
||||
"ffn.shared_experts.w1.weight": (gguf.MODEL_TENSOR.FFN_GATE_SHEXP, ".weight"),
|
||||
"ffn.shared_experts.w2.weight": (gguf.MODEL_TENSOR.FFN_DOWN_SHEXP, ".weight"),
|
||||
"ffn.shared_experts.w3.weight": (gguf.MODEL_TENSOR.FFN_UP_SHEXP, ".weight"),
|
||||
}
|
||||
|
||||
tensor_name = match.group(2)
|
||||
if tensor_name in layer_map:
|
||||
return layer_map[tensor_name]
|
||||
|
||||
if re.match(r"ffn\.experts\.\d+\.w[123]\.(weight|scale)$", tensor_name):
|
||||
return gguf.MODEL_TENSOR.FFN_GATE_EXP, ".weight"
|
||||
|
||||
raise ValueError(f"Unsupported DeepSeek-V4 tensor {name!r}")
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if re.match(r"layers\.\d+\.ffn\.experts\.\d+\.w[123]\.(weight|scale)$", name):
|
||||
return []
|
||||
|
||||
tensor_key, suffix = self._map_dsv4_tensor_name(name, bid)
|
||||
if tensor_key == gguf.MODEL_TENSOR.FFN_GATE_TID2EID:
|
||||
return []
|
||||
|
||||
return [(self._format_dsv4_tensor_name(tensor_key, bid, suffix), data_torch)]
|
||||
|
||||
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
|
||||
del new_name, bid # unused
|
||||
|
||||
if name in self._dsv4_fp8_dequantized and n_dims >= 2:
|
||||
return gguf.GGMLQuantizationType.Q8_0
|
||||
if name in self._dsv4_f32_tensors:
|
||||
return gguf.GGMLQuantizationType.F32
|
||||
if name in self._dsv4_bf16_tensors and n_dims >= 2:
|
||||
return gguf.GGMLQuantizationType.BF16
|
||||
|
||||
return False
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
self._is_mxfp4 = True
|
||||
self.ftype = gguf.LlamaFileType.MOSTLY_MXFP4_MOE
|
||||
|
||||
+3
-3
@@ -73,7 +73,7 @@ class LlamaModel(TextModel):
|
||||
target_num_layers = target_config["num_hidden_layers"]
|
||||
target_layers = [2, target_num_layers // 2, target_num_layers - 3]
|
||||
logger.info(f"EAGLE-3: target_layers = {target_layers} (target model has {target_num_layers} layers)")
|
||||
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", target_layers)
|
||||
self.gguf_writer.add_target_layers(target_layers)
|
||||
|
||||
# target_hidden_size: prefer eagle3 config, fallback to target config
|
||||
if eagle3_raw_config.get("target_hidden_size") is not None:
|
||||
@@ -83,12 +83,12 @@ class LlamaModel(TextModel):
|
||||
target_hidden_size = target_config["hidden_size"]
|
||||
src = "target model config"
|
||||
logger.info(f"EAGLE-3: target_hidden_size = {target_hidden_size} (from {src})")
|
||||
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size)
|
||||
self.gguf_writer.add_target_hidden_size(target_hidden_size)
|
||||
|
||||
# norm_before_residual (RedHat-style eagle3 specific)
|
||||
norm_before_residual = eagle3_raw_config.get("norm_before_residual", False)
|
||||
logger.info(f"EAGLE-3: norm_before_residual = {norm_before_residual}")
|
||||
self.gguf_writer.add_bool(f"{self.gguf_writer.arch}.norm_before_residual", norm_before_residual)
|
||||
self.gguf_writer.add_norm_before_residual(norm_before_residual)
|
||||
|
||||
def set_vocab(self):
|
||||
# eagle3: use tokenizer from target model if provided
|
||||
|
||||
@@ -625,3 +625,51 @@ class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReor
|
||||
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
|
||||
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN35MOE
|
||||
|
||||
|
||||
@ModelBase.register("DFlashDraftModel")
|
||||
class DFlashModel(Qwen3Model):
|
||||
model_arch = gguf.MODEL_ARCH.DFLASH
|
||||
|
||||
def set_vocab(self):
|
||||
if self.target_model_dir is None:
|
||||
raise ValueError(
|
||||
"DFlash draft model requires --target-model-dir to be specified. "
|
||||
"Please provide the path to the target model directory containing the tokenizer."
|
||||
)
|
||||
logger.info(f"DFlash: Using tokenizer from target model: {self.target_model_dir}")
|
||||
original_dir = self.dir_model
|
||||
self.dir_model = self.target_model_dir
|
||||
super().set_vocab()
|
||||
self.dir_model = original_dir
|
||||
|
||||
mask_token_id = self.hparams.get("dflash_config", {}).get("mask_token_id")
|
||||
if mask_token_id is not None:
|
||||
self.gguf_writer.add_mask_token_id(mask_token_id)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
block_size = self.hparams.get("block_size", 16)
|
||||
self.gguf_writer.add_block_size(block_size)
|
||||
dflash_config = self.hparams.get("dflash_config", {})
|
||||
|
||||
target_layer_ids = dflash_config.get("target_layer_ids", [])
|
||||
if target_layer_ids:
|
||||
extract_layer_ids = [i + 1 for i in target_layer_ids]
|
||||
self.gguf_writer.add_target_layers(extract_layer_ids)
|
||||
|
||||
use_sliding_window = self.hparams.get("use_sliding_window", False)
|
||||
sliding_window = self.hparams.get("sliding_window")
|
||||
layer_types = self.hparams.get("layer_types")
|
||||
if use_sliding_window and sliding_window and layer_types:
|
||||
is_swa = [lt == "sliding_attention" for lt in layer_types]
|
||||
self.gguf_writer.add_sliding_window(sliding_window)
|
||||
self.gguf_writer.add_sliding_window_pattern(is_swa)
|
||||
|
||||
@classmethod
|
||||
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
|
||||
name, gen = item
|
||||
if not name.startswith("model."):
|
||||
name = "model." + name
|
||||
return super().filter_tensors((name, gen))
|
||||
|
||||
+51
-39
@@ -1,16 +1,26 @@
|
||||
# llama.cpp for OpenCL
|
||||
|
||||
- [Background](#background)
|
||||
- [OS](#os)
|
||||
- [Hardware](#hardware)
|
||||
- [DataType Supports](#datatype-supports)
|
||||
- [Model Preparation](#model-preparation)
|
||||
- [CMake Options](#cmake-options)
|
||||
- [Android](#android)
|
||||
- [Windows 11 Arm64](#windows-11-arm64)
|
||||
- [Linux](#Linux)
|
||||
- [Known Issue](#known-issues)
|
||||
- [TODO](#todo)
|
||||
- [llama.cpp for OpenCL](#llamacpp-for-opencl)
|
||||
- [Background](#background)
|
||||
- [Llama.cpp + OpenCL](#llamacpp--opencl)
|
||||
- [OS](#os)
|
||||
- [Hardware](#hardware)
|
||||
- [Adreno GPU](#adreno-gpu)
|
||||
- [DataType Supports](#datatype-supports)
|
||||
- [Model Preparation](#model-preparation)
|
||||
- [Binary Kernel Library](#binary-kernel-library)
|
||||
- [CMake Options](#cmake-options)
|
||||
- [Android](#android)
|
||||
- [I. Setup Environment](#i-setup-environment)
|
||||
- [II. Build llama.cpp](#ii-build-llamacpp)
|
||||
- [Windows 11 Arm64](#windows-11-arm64)
|
||||
- [I. Setup Environment](#i-setup-environment-1)
|
||||
- [II. Build llama.cpp](#ii-build-llamacpp-1)
|
||||
- [Linux](#linux)
|
||||
- [I. Setup Environment](#i-setup-environment-2)
|
||||
- [II. Build llama.cpp](#ii-build-llamacpp-2)
|
||||
- [Known Issues](#known-issues)
|
||||
- [TODO](#todo)
|
||||
|
||||
## Background
|
||||
|
||||
@@ -34,11 +44,13 @@ The llama.cpp OpenCL backend is designed to enable llama.cpp on **Qualcomm Adren
|
||||
|
||||
**Verified devices**
|
||||
|
||||
| Adreno GPU | Status |
|
||||
|:------------------------------------:|:-------:|
|
||||
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
|
||||
| Adreno 830 (Snapdragon 8 Elite) | Support |
|
||||
| Adreno X85 (Snapdragon X Elite) | Support |
|
||||
| Adreno GPU | Status |
|
||||
|:-------------------------------------:|:-------:|
|
||||
| Adreno 750 (Snapdragon 8 Gen 3) | Support |
|
||||
| Adreno 830 (Snapdragon 8 Elite) | Support |
|
||||
| Adreno 840 (Snapdragon 8 Elite Gen 5) | Support |
|
||||
| Adreno X1-85 (Snapdragon X Elite) | Support |
|
||||
| Adreno X2-90 (Snapdragon X2 Elite) | Support |
|
||||
|
||||
> A6x GPUs with a recent driver and compiler are supported; they are usually found in IoT platforms.
|
||||
However, A6x GPUs in phones are likely not supported due to the outdated driver and compiler.
|
||||
@@ -47,42 +59,43 @@ However, A6x GPUs in phones are likely not supported due to the outdated driver
|
||||
|
||||
| DataType | Status |
|
||||
|:----------------------:|:--------------------------:|
|
||||
| Q1_0 | Support |
|
||||
| Q4_0 | Support |
|
||||
| Q6_K | Support, but not optimized |
|
||||
| Q4_1 | Support |
|
||||
| Q5_0 | Support |
|
||||
| Q5_1 | Support |
|
||||
| Q8_0 | Support |
|
||||
| Q4_K | Support |
|
||||
| Q5_K | Support |
|
||||
| Q6_K | Support |
|
||||
| MXFP4 | Support |
|
||||
| IQ4_NL | Support |
|
||||
|
||||
## Model Preparation
|
||||
|
||||
You can refer to the general [llama-quantize tool](/tools/quantize/README.md) for steps to convert a model in Hugging Face safetensor format to GGUF with quantization.
|
||||
Since common quantizations are supported now, it is recommanded to download GGUF models directly from Huggingface.
|
||||
|
||||
Currently we support `Q4_0` quantization and have optimized for it. To achieve best performance on Adreno GPU, add `--pure` to `llama-quantize` (i.e., make all weights in `Q4_0`). For example,
|
||||
## Binary Kernel Library
|
||||
|
||||
```sh
|
||||
./llama-quantize --pure ggml-model-qwen2.5-3b-f16.gguf ggml-model-qwen-3b-Q4_0.gguf Q4_0
|
||||
```
|
||||
A prebuilt binary kernel library has been introduced for Adreno GPUs.
|
||||
It currently targets X2 GPUs (X2-90, X2-85 and X2-45) found in Snapdragon X2 SoC.
|
||||
The library currently contains kernels for MUL_MAT_ID with Q4_0, Q4_1, Q4_K, MXFP4.
|
||||
The library must be manually downloaded from https://softwarecenter.qualcomm.com/catalog/item/Adreno_Kernel_Library_GGML.
|
||||
|
||||
Since `Q6_K` is also supported, `Q4_0` quantization without `--pure` will also work. However, the performance will be worse compared to pure `Q4_0` quantization.
|
||||
To allow using the kernel library, add `-DGGML_OPENCL_USE_ADRENO_BIN_KERNELS=ON` when configuring with CMake.
|
||||
Then, extract `adreno-opencl-kernels.dll` from the zip file downloaded from the above URL and put it alongside the executables.
|
||||
If kernels compatible with the current GPU are found in the library, they will be loaded and used.
|
||||
|
||||
### `MXFP4` MoE Models
|
||||
|
||||
OpenAI gpt-oss models are MoE models in `MXFP4`. The quantized model will be in `MXFP4_MOE`, a mixture of `MXFP4` and `Q8_0`.
|
||||
For this quantization, there is no need to specify `--pure`.
|
||||
For gpt-oss-20b model, you can directly [download](https://huggingface.co/ggml-org/gpt-oss-20b-GGUF) the quantized GGUF file in `MXFP4_MOE` from Hugging Face.
|
||||
|
||||
Although it is possible to quantize gpt-oss-20b model in pure `Q4_0` (all weights in `Q4_0`), it is not recommended since `MXFP4` has been optimized for MoE while `Q4_0` is not. In addition, accuracy should degrade with such pure `Q4_0` quantization.
|
||||
Hence, using the default `MXFP4_MOE` quantization (see the link above) is recommended for this model.
|
||||
|
||||
> Note that the `Q4_0` model found [here](https://huggingface.co/unsloth/gpt-oss-20b-GGUF/blob/main/gpt-oss-20b-Q4_0.gguf) is a mixture of `Q4_0`, `Q8_0` and `MXFP4` and gives better performance than `MXFP4_MOE` quantization.
|
||||
|
||||
## CMake Options
|
||||
|
||||
The OpenCL backend has the following CMake options that control the behavior of the backend.
|
||||
|
||||
| CMake options | Default value | Description |
|
||||
|:---------------------------------:|:--------------:|:------------------------------------------|
|
||||
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
|
||||
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
|
||||
| CMake options | Default value | Description |
|
||||
|:------------------------------------:|:--------------:|:------------------------------------------|
|
||||
| `GGML_OPENCL_EMBED_KERNELS` | `ON` | Embed OpenCL kernels into the executable. |
|
||||
| `GGML_OPENCL_USE_ADRENO_KERNELS` | `ON` | Use kernels optimized for Adreno. |
|
||||
| `GGML_OPENCL_USE_ADRENO_BIN_KERNELS` | `OFF` | Allow using binary kernel lib for Adreno. |
|
||||
|
||||
## Android
|
||||
|
||||
@@ -277,6 +290,5 @@ ninja
|
||||
|
||||
## TODO
|
||||
|
||||
- Optimization for Q6_K
|
||||
- Support and optimization for Q4_K
|
||||
- Improve flash attention
|
||||
- Improve OpenCL C kernels performance
|
||||
|
||||
@@ -237,8 +237,8 @@ chmod +x ubuntu-llamacpp-ov-install.sh
|
||||
# ============================================
|
||||
set -euo pipefail
|
||||
|
||||
OPENVINO_VERSION_MAJOR="2026.2"
|
||||
OPENVINO_VERSION_FULL="2026.2.0.21903.52ddc073857"
|
||||
OPENVINO_VERSION_MAJOR="2026.2.1"
|
||||
OPENVINO_VERSION_FULL="2026.2.1.21919.ede283a88e3"
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
OPENVINO_INSTALL_DIR="/opt/intel/openvino_${OPENVINO_VERSION_MAJOR}"
|
||||
@@ -334,7 +334,7 @@ echo " ./build/ReleaseOV/bin/llama-cli -m model.gguf"
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The script pins OpenVINO `2026.2` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release.
|
||||
> The script pins OpenVINO `2026.2.1` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release.
|
||||
|
||||
</details>
|
||||
|
||||
@@ -364,8 +364,8 @@ REM ============================================
|
||||
REM llama.cpp OpenVINO Build Script (Ninja)
|
||||
REM ============================================
|
||||
|
||||
set "OPENVINO_VERSION_MAJOR=2026.2"
|
||||
set "OPENVINO_VERSION_FULL=2026.2.0.21903.52ddc073857"
|
||||
set "OPENVINO_VERSION_MAJOR=2026.2.1"
|
||||
set "OPENVINO_VERSION_FULL=2026.2.1.21919.ede283a88e3"
|
||||
|
||||
set "SCRIPT_DIR=%~dp0"
|
||||
set "VCPKG_DIR=C:\vcpkg"
|
||||
@@ -547,7 +547,7 @@ endlocal
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The script pins OpenVINO `2026.2` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release. From any new shell, source the matching `setupvars` script via the junction — `call "C:\Intel\openvino\setupvars.bat"` from `cmd`, or `& "C:\Intel\openvino\setupvars.ps1"` from PowerShell. If `winget` cannot register Visual Studio Build Tools on first run, install them once manually and re-run the script from an elevated **Developer Command Prompt for VS 2022**.
|
||||
> The script pins OpenVINO `2026.2.1` via the `OPENVINO_VERSION_MAJOR` / `OPENVINO_VERSION_FULL` variables at the top — edit them to track a different release. From any new shell, source the matching `setupvars` script via the junction — `call "C:\Intel\openvino\setupvars.bat"` from `cmd`, or `& "C:\Intel\openvino\setupvars.ps1"` from PowerShell. If `winget` cannot register Visual Studio Build Tools on first run, install them once manually and re-run the script from an elevated **Developer Command Prompt for VS 2022**.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
+28
-1
@@ -52,6 +52,32 @@ Supported EAGLE-3 draft models include:
|
||||
|
||||
For the full and up-to-date list of supported models, see #18039.
|
||||
|
||||
### DFlash (`draft-dflash`)
|
||||
|
||||
DFlash produces an entire block of draft tokens in a single forward pass (block diffusion) and
|
||||
injects the target model's hidden states into the draft model's attention, instead of drafting one
|
||||
token at a time. This keeps the draft model small while making drafting GPU-friendly. Unlike EAGLE-3
|
||||
(a single-layer autoregressive draft), the DFlash draft uses several transformer layers but emits a
|
||||
whole block per draft step.
|
||||
|
||||
The draft is a small block-diffusion model trained for a specific target (for example
|
||||
`z-lab/Qwen3-4B-DFlash` for `Qwen/Qwen3-4B`). Convert it with `--target-model-dir` so it inherits the
|
||||
target's tokenizer and token embeddings:
|
||||
|
||||
```bash
|
||||
python convert_hf_to_gguf.py z-lab/Qwen3-4B-DFlash \
|
||||
--target-model-dir Qwen/Qwen3-4B --outtype bf16 --outfile Qwen3-4B-DFlash.gguf
|
||||
|
||||
llama-server -m Qwen3-4B.gguf -md Qwen3-4B-DFlash.gguf \
|
||||
--spec-type draft-dflash --spec-draft-n-max 15 -fa on --jinja
|
||||
```
|
||||
|
||||
`--spec-draft-n-max` is clamped to the draft model's trained block size.
|
||||
|
||||
See:
|
||||
|
||||
- #22105
|
||||
|
||||
### n-gram Cache (`ngram-cache`)
|
||||
|
||||
An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
|
||||
@@ -147,7 +173,7 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
|
||||
### General Speculative Parameters
|
||||
|
||||
```
|
||||
--spec-type [none|draft-simple|draft-eagle3|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
--spec-type [none|draft-simple|draft-eagle3|draft-dflash|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
|
||||
comma-separated list of types of speculative decoding to use
|
||||
(default: none)
|
||||
(env: LLAMA_ARG_SPEC_TYPE)
|
||||
@@ -287,6 +313,7 @@ Specifies a comma-separated list of speculative decoding types to use.
|
||||
| `none` | No speculative decoding (default) |
|
||||
| `draft-simple` | Use a simple draft model for speculation |
|
||||
| `draft-eagle3` | Use an EAGLE-3 draft model that reads the target's hidden states |
|
||||
| `draft-dflash` | Use a DFlash block-diffusion draft model that emits a block per step |
|
||||
| `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model |
|
||||
| `ngram-cache` | Use n-gram cache lookup |
|
||||
| `ngram-simple` | Use simple n-gram pattern matching |
|
||||
|
||||
+1
-1
@@ -5,7 +5,7 @@ project("ggml" C CXX ASM)
|
||||
### GGML Version
|
||||
set(GGML_VERSION_MAJOR 0)
|
||||
set(GGML_VERSION_MINOR 15)
|
||||
set(GGML_VERSION_PATCH 2)
|
||||
set(GGML_VERSION_PATCH 3)
|
||||
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
|
||||
|
||||
@@ -1111,11 +1111,12 @@ GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
|
||||
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
|
||||
GGML_TABLE_END()
|
||||
|
||||
// e2m1 values (doubled)
|
||||
// e2m1 values (doubled), shared by MXFP4 and NVFP4
|
||||
// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
||||
GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
|
||||
GGML_TABLE_BEGIN(int8_t, kvalues_fp4, 16)
|
||||
0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
|
||||
GGML_TABLE_END()
|
||||
#define kvalues_mxfp4 kvalues_fp4
|
||||
|
||||
#define NGRID_IQ1S 2048
|
||||
#define IQ1S_DELTA 0.125f
|
||||
|
||||
@@ -82,7 +82,6 @@
|
||||
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
|
||||
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
|
||||
// quants.c
|
||||
#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
|
||||
// repack.cpp
|
||||
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
|
||||
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
|
||||
|
||||
@@ -934,7 +934,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
|
||||
#if defined __AVX2__
|
||||
|
||||
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
|
||||
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
|
||||
const __m128i m4b = _mm_set1_epi8(0x0f);
|
||||
const __m256i mone = _mm256_set1_epi16(1);
|
||||
|
||||
@@ -963,7 +963,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
|
||||
|
||||
#elif defined __AVX__
|
||||
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
|
||||
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
|
||||
const __m128i m4b = _mm_set1_epi8(0x0f);
|
||||
|
||||
__m256 accum = _mm256_setzero_ps();
|
||||
@@ -993,14 +993,152 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
||||
int sumi1 = 0;
|
||||
int sumi2 = 0;
|
||||
for (int j = 0; j < QK_MXFP4/2; ++j) {
|
||||
sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
|
||||
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
|
||||
sumi1 += y[ib].qs[j + 0] * kvalues_fp4[x[ib].qs[j] & 0xf];
|
||||
sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_fp4[x[ib].qs[j] >> 4];
|
||||
}
|
||||
sumf += d * (sumi1 + sumi2);
|
||||
}
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
assert(n % QK_NVFP4 == 0);
|
||||
|
||||
const block_nvfp4 * GGML_RESTRICT x = vx;
|
||||
const block_q8_0 * GGML_RESTRICT y = vy;
|
||||
|
||||
const int nb = n / QK_NVFP4;
|
||||
int ib = 0;
|
||||
float sumf = 0;
|
||||
|
||||
#if defined(__AVX2__)
|
||||
|
||||
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
|
||||
const __m128i m4b = _mm_set1_epi8(0x0f);
|
||||
const __m256i mone = _mm256_set1_epi16(1);
|
||||
|
||||
__m256 accum = _mm256_setzero_ps();
|
||||
for(; ib < nb; ib++){
|
||||
|
||||
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
|
||||
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
|
||||
|
||||
const __m256i q8_01 = _mm256_loadu_si256((const __m256i *)y[2*ib + 0].qs);
|
||||
const __m256i q8_23 = _mm256_loadu_si256((const __m256i *)y[2*ib + 1].qs);
|
||||
|
||||
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
|
||||
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
|
||||
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
|
||||
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
|
||||
|
||||
//reordering
|
||||
const __m256i q4_01 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_01_lo,q4_01_hi), _mm_unpacklo_epi64(q4_01_lo,q4_01_hi));
|
||||
const __m256i q4_23 = MM256_SET_M128I(_mm_unpackhi_epi64(q4_23_lo,q4_23_hi),_mm_unpacklo_epi64(q4_23_lo,q4_23_hi));
|
||||
|
||||
const __m256i p01 = mul_add_epi8(q4_01,q8_01);
|
||||
const __m256i p_1 = _mm256_madd_epi16(p01, mone);
|
||||
|
||||
const __m256i p23 = mul_add_epi8(q4_23,q8_23);
|
||||
const __m256i p_2 = _mm256_madd_epi16(p23, mone);
|
||||
|
||||
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
|
||||
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
|
||||
|
||||
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
|
||||
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
|
||||
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
|
||||
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
|
||||
|
||||
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
|
||||
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
|
||||
|
||||
accum = _mm256_fmadd_ps(scales01, _mm256_cvtepi32_ps(p_1), accum);
|
||||
accum = _mm256_fmadd_ps(scales23, _mm256_cvtepi32_ps(p_2), accum);
|
||||
}
|
||||
sumf = hsum_float_8(accum);
|
||||
|
||||
#elif defined(__AVX__)
|
||||
|
||||
const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_fp4);
|
||||
const __m128i m4b = _mm_set1_epi8(0x0f);
|
||||
|
||||
__m256 accum = _mm256_setzero_ps();
|
||||
for(; ib < nb; ib++){
|
||||
|
||||
const __m128i q4bits_01 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 0));
|
||||
const __m128i q4bits_23 = _mm_loadu_si128((const __m128i *)(x[ib].qs + 16));
|
||||
|
||||
const __m128i q8_0 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 0));
|
||||
const __m128i q8_1 = _mm_loadu_si128((const __m128i *)(y[2*ib + 0].qs + 16));
|
||||
const __m128i q8_2 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 0));
|
||||
const __m128i q8_3 = _mm_loadu_si128((const __m128i *)(y[2*ib + 1].qs + 16));
|
||||
|
||||
const __m128i q4_01_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_01, m4b));
|
||||
const __m128i q4_01_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_01, 4), m4b));
|
||||
const __m128i q4_23_lo = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_23, m4b));
|
||||
const __m128i q4_23_hi = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_23, 4), m4b));
|
||||
|
||||
const __m128i q4_0 = _mm_unpacklo_epi64(q4_01_lo, q4_01_hi);
|
||||
const __m128i q4_1 = _mm_unpackhi_epi64(q4_01_lo, q4_01_hi);
|
||||
const __m128i q4_2 = _mm_unpacklo_epi64(q4_23_lo, q4_23_hi);
|
||||
const __m128i q4_3 = _mm_unpackhi_epi64(q4_23_lo, q4_23_hi);
|
||||
|
||||
const __m128i p0_i32 = mul_sum_i8_pairs(q4_0, q8_0);
|
||||
const __m128i p1_i32 = mul_sum_i8_pairs(q4_1, q8_1);
|
||||
const __m128i p2_i32 = mul_sum_i8_pairs(q4_2, q8_2);
|
||||
const __m128i p3_i32 = mul_sum_i8_pairs(q4_3, q8_3);
|
||||
|
||||
const __m128 p0 = _mm_cvtepi32_ps(p0_i32);
|
||||
const __m128 p1 = _mm_cvtepi32_ps(p1_i32);
|
||||
const __m128 p2 = _mm_cvtepi32_ps(p2_i32);
|
||||
const __m128 p3 = _mm_cvtepi32_ps(p3_i32);
|
||||
|
||||
const __m256 p01 = _mm256_set_m128(p1, p0);
|
||||
const __m256 p23 = _mm256_set_m128(p3, p2);
|
||||
|
||||
const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
|
||||
const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
|
||||
|
||||
const float s0 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[0]) * dy0;
|
||||
const float s1 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[1]) * dy0;
|
||||
const float s2 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[2]) * dy1;
|
||||
const float s3 = GGML_CPU_UE4M3_TO_FP32(x[ib].d[3]) * dy1;
|
||||
|
||||
const __m256 scales01 = _mm256_set_m128(_mm_set1_ps(s1), _mm_set1_ps(s0));
|
||||
const __m256 scales23 = _mm256_set_m128(_mm_set1_ps(s3), _mm_set1_ps(s2));
|
||||
|
||||
accum = _mm256_add_ps(accum, _mm256_mul_ps(p01, scales01));
|
||||
accum = _mm256_add_ps(accum, _mm256_mul_ps(p23, scales23));
|
||||
}
|
||||
sumf = hsum_float_8(accum);
|
||||
|
||||
#endif
|
||||
|
||||
for (;ib < nb; ++ib) {
|
||||
for (int s_idx = 0; s_idx < 4; ++s_idx) {
|
||||
const float d = GGML_CPU_UE4M3_TO_FP32(x[ib].d[s_idx]);
|
||||
const int q8_block = s_idx / 2;
|
||||
const int q8_off = (s_idx % 2) * QK_NVFP4_SUB;
|
||||
const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d);
|
||||
|
||||
int sumi_lo = 0, sumi_hi = 0;
|
||||
for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
|
||||
const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j];
|
||||
sumi_lo += y[2*ib + q8_block].qs[q8_off + j + 0] * kvalues_fp4[qv & 0xf];
|
||||
sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_fp4[qv >> 4];
|
||||
}
|
||||
|
||||
sumf += dy * d * (sumi_lo + sumi_hi);
|
||||
}
|
||||
}
|
||||
*s = sumf;
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
@@ -82,6 +82,9 @@ float ggml_table_f32_f16[1 << 16];
|
||||
// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h)
|
||||
float ggml_table_f32_e8m0_half[1 << 8];
|
||||
|
||||
// precomputed f32 table for ue4m3 (1 KB) (simd-mappings.h)
|
||||
float ggml_table_f32_ue4m3[1 << 8];
|
||||
|
||||
#if defined(__ARM_ARCH)
|
||||
struct ggml_arm_arch_features_type {
|
||||
int sve_cnt;
|
||||
@@ -3798,6 +3801,11 @@ void ggml_cpu_init(void) {
|
||||
ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i);
|
||||
}
|
||||
|
||||
// initialize UE4M3 table (256 entries)
|
||||
for (int i = 0; i < (1 << 8); ++i) {
|
||||
ggml_table_f32_ue4m3[i] = ggml_ue4m3_to_fp32(i);
|
||||
}
|
||||
|
||||
const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
|
||||
|
||||
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);
|
||||
|
||||
@@ -120,6 +120,10 @@ extern float ggml_table_f32_f16[1 << 16];
|
||||
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
|
||||
extern float ggml_table_f32_e8m0_half[1 << 8];
|
||||
|
||||
// precomputed f32 table for ue4m3 (1 KB)
|
||||
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
|
||||
extern float ggml_table_f32_ue4m3[1 << 8];
|
||||
|
||||
// Use lookup table for E8M0 on x86 (faster than bit manipulation)
|
||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)]
|
||||
@@ -127,6 +131,13 @@ extern float ggml_table_f32_e8m0_half[1 << 8];
|
||||
#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x)
|
||||
#endif
|
||||
|
||||
// Use lookup table for UE4M3 on x86 (faster than bit manipulation)
|
||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
||||
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_table_f32_ue4m3[(uint8_t)(x)]
|
||||
#else
|
||||
#define GGML_CPU_UE4M3_TO_FP32(x) ggml_ue4m3_to_fp32(x)
|
||||
#endif
|
||||
|
||||
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
|
||||
// so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON.
|
||||
// This is also true for POWER9.
|
||||
|
||||
@@ -386,6 +386,46 @@ static void ggml_cpy_f32_iq4_nl_cuda(
|
||||
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
}
|
||||
|
||||
// check if a same-type copy reduces to a 2D strided copy (height rows of width
|
||||
// contiguous bytes), so it can use cudaMemcpy2DAsync instead of the scalar kernel
|
||||
static bool ggml_cuda_cpy_as_memcpy_2d(const ggml_tensor * src0, const ggml_tensor * src1,
|
||||
size_t & width, size_t & height, size_t & spitch, size_t & dpitch) {
|
||||
// require matching shape: a reshaped copy maps elements by flat order, which the
|
||||
// prefix walk below does not handle
|
||||
if (src0->type != src1->type || !ggml_are_same_shape(src0, src1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// grow the contiguous prefix block shared by both tensors
|
||||
size_t block_nb = ggml_element_size(src0);
|
||||
int d = 0;
|
||||
for (; d < GGML_MAX_DIMS; ++d) {
|
||||
if (src0->nb[d] != block_nb || src1->nb[d] != block_nb) {
|
||||
break;
|
||||
}
|
||||
block_nb *= src0->ne[d];
|
||||
}
|
||||
|
||||
// d == 0: nothing contiguous; d == GGML_MAX_DIMS: fully contiguous (handled by memcpy)
|
||||
if (d == 0 || d == GGML_MAX_DIMS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// dim d carries the rows; everything above it must be a single element
|
||||
for (int i = d + 1; i < GGML_MAX_DIMS; ++i) {
|
||||
if (src0->ne[i] != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
width = block_nb;
|
||||
height = src0->ne[d];
|
||||
spitch = src0->nb[d];
|
||||
dpitch = src1->nb[d];
|
||||
|
||||
return spitch >= width && dpitch >= width;
|
||||
}
|
||||
|
||||
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne == ggml_nelements(src1));
|
||||
@@ -421,6 +461,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
|
||||
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
|
||||
|
||||
size_t mc_width = 0, mc_height = 0, mc_spitch = 0, mc_dpitch = 0;
|
||||
|
||||
if (src0->type == src1->type && contiguous_srcs) {
|
||||
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
|
||||
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
|
||||
@@ -431,6 +473,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
{
|
||||
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
|
||||
}
|
||||
} else if (ggml_cuda_cpy_as_memcpy_2d(src0, src1, mc_width, mc_height, mc_spitch, mc_dpitch)) {
|
||||
CUDA_CHECK(cudaMemcpy2DAsync(src1_ddc, mc_dpitch, src0_ddc, mc_spitch,
|
||||
mc_width, mc_height, cudaMemcpyDeviceToDevice, main_stream));
|
||||
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
|
||||
if (can_be_transposed) {
|
||||
ggml_cpy_scalar_cuda<float, float, true>
|
||||
|
||||
@@ -664,7 +664,10 @@ constexpr __device__ dequantize_V_t get_dequantize_V() {
|
||||
template <int ncols1>
|
||||
__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
|
||||
static __global__ void flash_attn_mask_to_KV_max(
|
||||
const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
|
||||
const half2 * mask_ptr, int * KV_max_ptr, const int ne30, const int64_t s31, const int64_t s33) {
|
||||
const half2 * GGML_CUDA_RESTRICT mask = mask_ptr;
|
||||
int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr;
|
||||
|
||||
const int ne31 = gridDim.x;
|
||||
const int tid = threadIdx.x;
|
||||
const int sequence = blockIdx.y;
|
||||
@@ -1089,8 +1092,8 @@ void launch_fattn(
|
||||
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
|
||||
// multiple sequences of possibly different lengths.
|
||||
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
|
||||
const int s31 = mask->nb[1] / sizeof(half2);
|
||||
const int s33 = mask->nb[3] / sizeof(half2);
|
||||
const int64_t s31 = mask->nb[1] / sizeof(half2);
|
||||
const int64_t s33 = mask->nb[3] / sizeof(half2);
|
||||
|
||||
const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
|
||||
const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
|
||||
@@ -1099,8 +1102,9 @@ void launch_fattn(
|
||||
const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
|
||||
|
||||
KV_max.alloc(ne_KV_max);
|
||||
flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
|
||||
((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
|
||||
ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_KV_max, block_dim_KV_max, 0, main_stream);
|
||||
ggml_cuda_kernel_launch(flash_attn_mask_to_KV_max<ncols1>, launch_params,
|
||||
(const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
|
||||
@@ -2003,6 +2003,10 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
|
||||
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 2);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 2);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 2);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 32, 2);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
|
||||
extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
|
||||
|
||||
@@ -76,6 +76,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
@@ -144,6 +145,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64)
|
||||
@@ -219,6 +221,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64)
|
||||
@@ -296,6 +299,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64)
|
||||
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64)
|
||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64)
|
||||
@@ -1308,12 +1312,12 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio % 2 == 0) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (DV <= 256) {
|
||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -99,12 +99,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (DKQ <= 256) {
|
||||
if (use_gqa_opt && gqa_ratio > 1) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
if (use_gqa_opt && gqa_ratio > 1) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
|
||||
return;
|
||||
}
|
||||
|
||||
if constexpr (DKQ <= 256) {
|
||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
|
||||
} else {
|
||||
GGML_ABORT("fatal error");
|
||||
|
||||
@@ -10,6 +10,7 @@ gated_delta_net_cuda(const float * q,
|
||||
const float * beta,
|
||||
const float * curr_state,
|
||||
float * dst,
|
||||
float * state,
|
||||
int64_t H,
|
||||
int64_t n_tokens,
|
||||
int64_t n_seqs,
|
||||
@@ -25,6 +26,7 @@ gated_delta_net_cuda(const float * q,
|
||||
const uint3 neqk1_magic,
|
||||
const uint3 rq3_magic,
|
||||
float scale,
|
||||
int64_t state_slot_stride,
|
||||
int K) {
|
||||
const uint32_t h_idx = blockIdx.x;
|
||||
const uint32_t sequence = blockIdx.y;
|
||||
@@ -35,9 +37,7 @@ gated_delta_net_cuda(const float * q,
|
||||
const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
|
||||
const uint32_t iq3 = fastdiv(sequence, rq3_magic);
|
||||
|
||||
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
||||
float * attn_data = dst;
|
||||
float * state = dst + attn_score_elems;
|
||||
|
||||
// input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v.
|
||||
// output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before.
|
||||
@@ -145,10 +145,9 @@ gated_delta_net_cuda(const float * q,
|
||||
if constexpr (keep_rs_t) {
|
||||
// snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back.
|
||||
// When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned.
|
||||
const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output
|
||||
const int target_slot = (int) n_tokens - 1 - t;
|
||||
if (target_slot >= 0 && target_slot < K) {
|
||||
float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset;
|
||||
float * curr_state = state + target_slot * state_slot_stride;
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
@@ -171,13 +170,13 @@ template <bool KDA, bool keep_rs_t>
|
||||
static void launch_gated_delta_net(
|
||||
const float * q_d, const float * k_d, const float * v_d,
|
||||
const float * g_d, const float * b_d, const float * s_d,
|
||||
float * dst_d,
|
||||
float * dst_d, float * state_d,
|
||||
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
|
||||
int64_t sq1, int64_t sq2, int64_t sq3,
|
||||
int64_t sv1, int64_t sv2, int64_t sv3,
|
||||
int64_t sb1, int64_t sb2, int64_t sb3,
|
||||
int64_t neqk1, int64_t rq3,
|
||||
float scale, int K, cudaStream_t stream) {
|
||||
float scale, int64_t state_slot_stride, int K, cudaStream_t stream) {
|
||||
//TODO: Add chunked kernel for even faster pre-fill
|
||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||
const int num_warps = 4;
|
||||
@@ -187,34 +186,32 @@ static void launch_gated_delta_net(
|
||||
const uint3 neqk1_magic = init_fastdiv_values(neqk1);
|
||||
const uint3 rq3_magic = init_fastdiv_values(rq3);
|
||||
|
||||
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream);
|
||||
switch (S_v) {
|
||||
case 16:
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
|
||||
break;
|
||||
case 32:
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
|
||||
break;
|
||||
case 64: {
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
|
||||
break;
|
||||
}
|
||||
case 128: {
|
||||
ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K);
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, state_slot_stride, K);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -223,7 +220,8 @@ static void launch_gated_delta_net(
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
static void ggml_cuda_op_gated_delta_net_impl(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, const ggml_cuda_gated_delta_net_fused_cache * cache) {
|
||||
ggml_tensor * src_q = dst->src[0];
|
||||
ggml_tensor * src_k = dst->src[1];
|
||||
ggml_tensor * src_v = dst->src[2];
|
||||
@@ -288,25 +286,42 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||
const int K = ggml_get_op_params_i32(dst, 0);
|
||||
const bool keep_rs = K > 1;
|
||||
|
||||
// recurrent state -> gdn_out tail (after attention scores), or the cache when fusing
|
||||
float * state_d = dst_d + S_v * H * n_tokens * n_seqs;
|
||||
int64_t state_slot_stride = S_v * S_v * H * n_seqs;
|
||||
if (cache != nullptr) {
|
||||
state_d = cache->data;
|
||||
state_slot_stride = cache->slot_stride;
|
||||
}
|
||||
|
||||
if (kda) {
|
||||
if (keep_rs) {
|
||||
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
|
||||
} else {
|
||||
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
|
||||
}
|
||||
} else {
|
||||
if (keep_rs) {
|
||||
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
|
||||
} else {
|
||||
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
|
||||
launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, state_d,
|
||||
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, K, stream);
|
||||
sb1, sb2, sb3, neqk1, rq3, scale, state_slot_stride, K, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
ggml_cuda_op_gated_delta_net_impl(ctx, dst, nullptr);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gated_delta_net_fused_cache(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_cuda_gated_delta_net_fused_cache cache) {
|
||||
ggml_cuda_op_gated_delta_net_impl(ctx, dst, &cache);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,14 @@
|
||||
#include "common.cuh"
|
||||
#include "ggml.h"
|
||||
|
||||
// fused-kernel recurrent-state output; strides in elements (per-seq stride is always D, set in-kernel)
|
||||
struct ggml_cuda_gated_delta_net_fused_cache {
|
||||
float * data; // rollback slot 0
|
||||
int64_t slot_stride; // between rollback slots (0 when K==1)
|
||||
};
|
||||
|
||||
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
// same op, but writes the snapshot(s) into the cache instead of dst (see ggml_cuda_try_gdn_cache_fusion)
|
||||
void ggml_cuda_op_gated_delta_net_fused_cache(ggml_backend_cuda_context & ctx, ggml_tensor * dst,
|
||||
ggml_cuda_gated_delta_net_fused_cache cache);
|
||||
|
||||
@@ -78,26 +78,29 @@ static __global__ void k_get_rows_float(
|
||||
|
||||
template<typename grad_t, typename dst_t>
|
||||
static __global__ void k_get_rows_back_float(
|
||||
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
|
||||
const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst,
|
||||
const int64_t ncols, const int64_t nrows_grad, const int64_t nrows_dst) {
|
||||
const int col = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
|
||||
if (col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||
|
||||
float sum = 0.0f;
|
||||
|
||||
ggml_cuda_pdl_sync();
|
||||
for (int64_t i = 0; i < nrows_grad; ++i) {
|
||||
if (rows[i] != dst_row) {
|
||||
continue;
|
||||
}
|
||||
sum += grad[i*ncols + col];
|
||||
}
|
||||
|
||||
dst[dst_row*ncols + col] = sum;
|
||||
// grid.y is clamped to the CUDA grid limit, so stride over the destination rows
|
||||
for (int64_t dst_row = blockIdx.y; dst_row < nrows_dst; dst_row += gridDim.y) {
|
||||
float sum = 0.0f;
|
||||
|
||||
for (int64_t i = 0; i < nrows_grad; ++i) {
|
||||
if (rows[i] != dst_row) {
|
||||
continue;
|
||||
}
|
||||
sum += grad[i*ncols + col];
|
||||
}
|
||||
|
||||
dst[dst_row*ncols + col] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
|
||||
@@ -302,7 +305,7 @@ void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
|
||||
|
||||
const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
|
||||
const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
|
||||
const dim3 block_nums(block_num_x, ne1, 1);
|
||||
const dim3 block_nums(block_num_x, MIN(ne1, (int64_t)UINT16_MAX), 1);
|
||||
|
||||
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
|
||||
k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10, ne1);
|
||||
}
|
||||
|
||||
@@ -3251,6 +3251,11 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
static bool ggml_cuda_is_view_or_noop(const ggml_tensor * t) {
|
||||
return ggml_is_empty(t) || t->op == GGML_OP_RESHAPE || t->op == GGML_OP_TRANSPOSE ||
|
||||
t->op == GGML_OP_VIEW || t->op == GGML_OP_PERMUTE || t->op == GGML_OP_NONE;
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA_GRAPH
|
||||
static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
|
||||
@@ -3260,7 +3265,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||
if (ggml_cuda_is_view_or_noop(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -3403,6 +3408,70 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
|
||||
return true;
|
||||
}
|
||||
|
||||
// match gated_delta_net + the strided cpy that scatters its state snapshots into the cache
|
||||
// (slot i -> rollback group i, slot 0 newest), so the kernel can write them and skip the cpy.
|
||||
static int ggml_cuda_try_gdn_cache_fusion(
|
||||
const ggml_cgraph * cgraph, int node_idx, ggml_cuda_gated_delta_net_fused_cache & fused_state_cpy) {
|
||||
const ggml_tensor * gdn = cgraph->nodes[node_idx];
|
||||
// the kernel skips the snapshot tail, so the gdn output must not be a graph output
|
||||
if (gdn->op != GGML_OP_GATED_DELTA_NET || gdn->type != GGML_TYPE_F32 ||
|
||||
(gdn->flags & GGML_TENSOR_FLAG_OUTPUT)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const ggml_tensor * src_v = gdn->src[2];
|
||||
const int64_t S_v = src_v->ne[0];
|
||||
const int64_t H = src_v->ne[1];
|
||||
const int64_t n_tokens = src_v->ne[2];
|
||||
const int64_t n_seqs = src_v->ne[3];
|
||||
const int64_t D = S_v * S_v * H;
|
||||
const int64_t K = ggml_get_op_params_i32(gdn, 0); // snapshot slot count
|
||||
const int64_t n_written = std::min<int64_t>(n_tokens, K); // newest n_written slots are written
|
||||
|
||||
// snapshot tail starts right after the attention scores
|
||||
const size_t tail_off = ggml_row_size(GGML_TYPE_F32, S_v * H * n_tokens * n_seqs);
|
||||
|
||||
// snapshot cpy is the first real node after the gdn (skip views/no-ops)
|
||||
const ggml_tensor * cpy = nullptr;
|
||||
int skip = 0;
|
||||
for (int j = node_idx + 1; j < cgraph->n_nodes && cpy == nullptr; ++j) {
|
||||
const ggml_tensor * n = cgraph->nodes[j];
|
||||
if (ggml_cuda_is_view_or_noop(n)) {
|
||||
continue;
|
||||
}
|
||||
if (n->op != GGML_OP_CPY || (n->flags & GGML_TENSOR_FLAG_OUTPUT)) {
|
||||
return 0;
|
||||
}
|
||||
cpy = n;
|
||||
skip = j - node_idx;
|
||||
}
|
||||
if (cpy == nullptr) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const ggml_tensor * src = cpy->src[0]; // view of the gdn snapshot tail
|
||||
const ggml_tensor * dst = cpy->src[1]; // cache view the kernel writes to
|
||||
|
||||
// src must be this gdn's snapshot tail (contiguous, at the tail offset)
|
||||
if (src->op != GGML_OP_VIEW || src->view_src != gdn || src->view_offs != tail_off ||
|
||||
!ggml_is_contiguous(src)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// dst is the [D, n_seqs, n_written] cache view; require nb[1] == D (the per-seq stride the kernel
|
||||
// assumes). ggml_cpy pins src to the same element count.
|
||||
const std::array<int64_t, GGML_MAX_DIMS> expected_ne = { D, n_seqs, n_written, 1 };
|
||||
if (dst->op != GGML_OP_VIEW || dst->type != GGML_TYPE_F32 || dst->data == nullptr ||
|
||||
!std::equal(expected_ne.begin(), expected_ne.end(), dst->ne) ||
|
||||
dst->nb[0] != ggml_type_size(GGML_TYPE_F32) || dst->nb[1] != (size_t) ggml_row_size(GGML_TYPE_F32, D)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
fused_state_cpy.data = (float *) dst->data; // rollback group 0 (newest)
|
||||
fused_state_cpy.slot_stride = K > 1 ? (int64_t) (dst->nb[2] / sizeof(float)) : 0;
|
||||
return skip;
|
||||
}
|
||||
|
||||
static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
|
||||
args.sigmoid = false;
|
||||
args.softmax = false;
|
||||
@@ -3844,6 +3913,20 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
|
||||
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
// gated_delta_net -> cpy: scatter recurrent-state snapshots into the cache
|
||||
if (node->op == GGML_OP_GATED_DELTA_NET) {
|
||||
ggml_cuda_gated_delta_net_fused_cache fused_state_cpy;
|
||||
const int nodes_to_skip = ggml_cuda_try_gdn_cache_fusion(cgraph, i, fused_state_cpy);
|
||||
if (nodes_to_skip > 0) {
|
||||
#ifdef GGML_CUDA_DEBUG
|
||||
GGML_LOG_INFO("%s: fused gated_delta_net snapshot copies for %s (skipped %d nodes)\n",
|
||||
__func__, node->name, nodes_to_skip);
|
||||
#endif
|
||||
ggml_cuda_op_gated_delta_net_fused_cache(*cuda_ctx, node, fused_state_cpy);
|
||||
return nodes_to_skip;
|
||||
}
|
||||
}
|
||||
|
||||
//topk-moe
|
||||
if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
|
||||
cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
|
||||
@@ -4372,7 +4455,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
|
||||
#endif
|
||||
prev_i = i;
|
||||
|
||||
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
|
||||
if (ggml_cuda_is_view_or_noop(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
@@ -368,5 +368,12 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
|
||||
return true;
|
||||
}
|
||||
|
||||
// gfx900 (Vega 10) lacks native dp4a, loses to dequant + hipBLAS
|
||||
// for dense matrices; keep MMQ only for MoE, where the
|
||||
// hipBLAS path is much slower.
|
||||
if (cc == GGML_CUDA_CC_VEGA) {
|
||||
return n_experts > 0;
|
||||
}
|
||||
|
||||
return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
|
||||
}
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 16, 2);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 32, 2);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 4, 2);
|
||||
|
||||
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);
|
||||
DECL_FATTN_MMA_F16_CASE(512, 512, 8, 2);
|
||||
|
||||
@@ -92,7 +92,7 @@ for ncols in [8, 16, 32, 64]:
|
||||
continue
|
||||
if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4
|
||||
continue
|
||||
if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4
|
||||
if head_size_kq == 512 and ncols2 not in (2, 4, 8): # Gemma 4 (+ MTP)
|
||||
continue
|
||||
if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash
|
||||
continue
|
||||
|
||||
@@ -312,6 +312,10 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<256, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 288: // StepFun 3.7
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<288, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
break;
|
||||
case 512:
|
||||
ggml_cuda_kernel_launch(topk_moe_cuda<512, has_bias>, launch_params,
|
||||
logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config);
|
||||
@@ -377,8 +381,10 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
|
||||
const ggml_tensor * weights,
|
||||
const ggml_tensor * logits,
|
||||
const ggml_tensor * ids) {
|
||||
// must match an instantiation of launch_topk_moe_cuda: a power of 2 up to 512,
|
||||
// or one of the non-power-of-2 expert counts of supported models
|
||||
const int n_expert = ids->nb[1] / ids->nb[0];
|
||||
if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
|
||||
if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 288 && n_expert != 576) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
|
||||
include(ExternalProject)
|
||||
|
||||
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
|
||||
option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF)
|
||||
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
|
||||
|
||||
add_library(htp_iface OBJECT
|
||||
|
||||
@@ -43,6 +43,7 @@
|
||||
#include "htp-opnode.h"
|
||||
#include "htp-ops.h"
|
||||
#include "htp/matmul-ops.h"
|
||||
#include "htp/flash-attn-ops.h"
|
||||
#include "htp_iface.h"
|
||||
#include "htp-drv.h"
|
||||
|
||||
@@ -62,6 +63,7 @@ static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu)
|
||||
static int opt_hostbuf = 1; // hostbuf ON by default
|
||||
|
||||
static int opt_mm_select = 3; // 3 = HMX -> Tiled -> Flat -> CPU, 2 = Tiled -> Flat -> CPU, 1 = Flat -> CPU
|
||||
static int opt_fa_select = 2; // 2 = HMX -> HVX -> CPU, 1 = HVX -> CPU, 0 = CPU (unsupported)
|
||||
|
||||
// Default PMU events, if profiling with PMU (mode=2) is enabled
|
||||
// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html
|
||||
@@ -125,6 +127,11 @@ static const char * htp_event_name(uint16_t id) {
|
||||
case HTP_TRACE_EVT_HVX_W_DEQUANT: return "HVX_W_DEQUANT";
|
||||
case HTP_TRACE_EVT_HVX_W_PREP: return "HVX_W_PREP";
|
||||
case HTP_TRACE_EVT_HVX_O_PROC: return "HVX_O_PROC";
|
||||
case HTP_TRACE_EVT_HVX_FA_QK: return "HVX_QK_FA";
|
||||
case HTP_TRACE_EVT_HVX_FA_SFM: return "HVX_SFM_FA";
|
||||
case HTP_TRACE_EVT_HVX_FA_Q_PREP: return "HVX_Q_PREP";
|
||||
case HTP_TRACE_EVT_HVX_FA_K_PREP: return "HVX_K_PREP";
|
||||
case HTP_TRACE_EVT_HVX_FA_V_PREP: return "HVX_V_PREP";
|
||||
case HTP_TRACE_EVT_HMX_COMP: return "HMX_COMP";
|
||||
default: return "UNKNOWN";
|
||||
}
|
||||
@@ -1879,6 +1886,162 @@ ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) {
|
||||
|
||||
// ** backend interface
|
||||
|
||||
static bool ggml_hexagon_flash_attn_is_hmx_eligible(
|
||||
const struct ggml_hexagon_session * sess,
|
||||
const struct ggml_tensor * q,
|
||||
const struct ggml_tensor * k,
|
||||
const struct ggml_tensor * v,
|
||||
const struct ggml_tensor * sinks
|
||||
) {
|
||||
if (sess->n_hmx == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (opt_fa_select < 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (k->type != GGML_TYPE_F16 || v->type != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const uint32_t DK = q->ne[0];
|
||||
const uint32_t DV = v->ne[0];
|
||||
|
||||
if (DK % 64 != 0 || DV % 64 != 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Fall back to HVX for small token counts if head dimension is small (DK <= 128)
|
||||
const uint32_t neq1 = q->ne[1];
|
||||
if (DK <= 128 && neq1 < 5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_precompute_flash_attn_params(
|
||||
const struct ggml_hexagon_session * sess,
|
||||
const struct ggml_tensor * op,
|
||||
struct htp_fa_kernel_params * kparams
|
||||
) {
|
||||
if (opt_fa_select < 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
memset(kparams, 0, sizeof(*kparams));
|
||||
|
||||
const struct ggml_tensor * q = op->src[0];
|
||||
const struct ggml_tensor * k = op->src[1];
|
||||
const struct ggml_tensor * v = op->src[2];
|
||||
const struct ggml_tensor * mask = op->src[3];
|
||||
const struct ggml_tensor * dst = op;
|
||||
|
||||
const uint32_t neq0 = q->ne[0]; // head_dim (DK)
|
||||
const uint32_t neq1 = q->ne[1]; // n_tokens
|
||||
const uint32_t neq2 = q->ne[2]; // n_heads
|
||||
|
||||
const uint32_t nek1 = k->ne[1]; // kv_len
|
||||
|
||||
const uint32_t nev0 = v->ne[0]; // head_dim (DV)
|
||||
|
||||
const uint32_t DK = neq0;
|
||||
const uint32_t DV = nev0;
|
||||
|
||||
const uint32_t n_kv_heads = k->ne[2];
|
||||
const uint32_t G = neq2 / n_kv_heads;
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
memcpy(&scale, &op->op_params[0], sizeof(float));
|
||||
memcpy(&max_bias, &op->op_params[1], sizeof(float));
|
||||
memcpy(&logit_softcap, &op->op_params[2], sizeof(float));
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
kparams->scale = scale;
|
||||
kparams->max_bias = max_bias;
|
||||
kparams->logit_softcap = logit_softcap;
|
||||
|
||||
kparams->is_q_fp32 = (q->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
kparams->is_dst_fp32 = (dst->type == GGML_TYPE_F32) ? 1 : 0;
|
||||
kparams->G = G;
|
||||
|
||||
const uint32_t n_head = q->ne[2];
|
||||
kparams->n_head_log2 = 1u << (uint32_t) std::floor(std::log2(n_head));
|
||||
kparams->m0 = std::pow(2.0f, -(max_bias) / kparams->n_head_log2);
|
||||
kparams->m1 = std::pow(2.0f, -(max_bias / 2.0f) / kparams->n_head_log2);
|
||||
|
||||
// Check HMX eligibility
|
||||
const struct ggml_tensor * sinks = op->src[4];
|
||||
if (ggml_hexagon_flash_attn_is_hmx_eligible(sess, q, k, v, sinks)) {
|
||||
size_t Br = 0, Bc = 0;
|
||||
int ret = hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, sess->vtcm_size, sess->n_threads);
|
||||
if (ret == 0) {
|
||||
kparams->kernel_type = HTP_FA_KERNEL_HMX;
|
||||
kparams->Br = Br;
|
||||
kparams->Bc = Bc;
|
||||
kparams->n_kv_blocks = (nek1 + Bc - 1) / Bc;
|
||||
kparams->n_threads = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? sess->n_threads : 1;
|
||||
|
||||
kparams->u.hmx.g_br = hex_align_up(G * Br, 32);
|
||||
kparams->u.hmx.pipeline = (kparams->n_kv_blocks >= 3 && sess->n_threads >= 2) ? 1 : 0;
|
||||
kparams->vtcm_size = hmx_fa_compute_vtcm_usage(G, DK, DV, Br, Bc, kparams->n_threads, kparams->u.hmx.pipeline != 0);
|
||||
|
||||
const size_t row_vec_bytes = hex_align_up(Bc * sizeof(uint16_t), 256);
|
||||
kparams->u.hmx.row_buf_stride = row_vec_bytes / 128; // HVX vector is 128 bytes
|
||||
|
||||
const size_t m_line_bytes = hex_align_up(Bc * sizeof(uint16_t), 128);
|
||||
kparams->u.hmx.mask_buf_row_stride = m_line_bytes / sizeof(uint16_t);
|
||||
kparams->u.hmx.mask_broadcast = (mask != nullptr && mask->ne[2] == 1) ? 1 : 0;
|
||||
kparams->u.hmx.div_G = init_fastdiv_values(G);
|
||||
if (mask) {
|
||||
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
|
||||
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
}
|
||||
|
||||
kparams->qrows = 0;
|
||||
kparams->qrows_per_thread = 0;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to HVX
|
||||
kparams->kernel_type = HTP_FA_KERNEL_HVX;
|
||||
kparams->Br = 1;
|
||||
kparams->Bc = 64; // FLASH_ATTN_BLOCK_SIZE
|
||||
kparams->n_kv_blocks = (k->ne[1] + 64 - 1) / 64;
|
||||
kparams->n_threads = sess->n_threads;
|
||||
|
||||
const size_t size_q_row_padded = hex_round_up(q->ne[0] * (kparams->is_q_fp32 ? 4 : 2), 128);
|
||||
const size_t size_k_row_padded = hex_round_up(k->ne[0] * 2, 128);
|
||||
const size_t size_v_row_padded = hex_round_up(v->ne[0] * 2, 128);
|
||||
|
||||
kparams->vtcm_size = hvx_fa_compute_vtcm_usage(DK, DV, kparams->is_q_fp32 != 0, mask != nullptr, sess->n_threads);
|
||||
|
||||
kparams->u.hvx.size_q_row_padded = size_q_row_padded;
|
||||
kparams->u.hvx.size_k_row_padded = size_k_row_padded;
|
||||
kparams->u.hvx.size_v_row_padded = size_v_row_padded;
|
||||
kparams->u.hvx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
|
||||
kparams->u.hvx.src0_div1 = init_fastdiv_values(q->ne[1]);
|
||||
kparams->u.hvx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
|
||||
kparams->u.hvx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
|
||||
kparams->u.hvx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
|
||||
kparams->u.hvx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
|
||||
if (mask) {
|
||||
kparams->src3_div2 = init_fastdiv_values(mask->ne[2]);
|
||||
kparams->src3_div3 = init_fastdiv_values(mask->ne[3]);
|
||||
}
|
||||
|
||||
kparams->qrows = q->ne[1] * q->ne[2] * q->ne[3];
|
||||
kparams->qrows_per_thread = (kparams->qrows + sess->n_threads - 1) / sess->n_threads;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
|
||||
const struct ggml_tensor * src0 = op->src[0];
|
||||
@@ -1912,6 +2075,17 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess
|
||||
return false;
|
||||
}
|
||||
|
||||
struct htp_fa_kernel_params kparams;
|
||||
if (!ggml_hexagon_precompute_flash_attn_params(sess, op, &kparams)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((size_t) kparams.vtcm_size > sess->vtcm_size) {
|
||||
HEX_VERBOSE("ggml-hex: skip flash_attn_ext because VTCM needed (%d) > budget (%zu)\n",
|
||||
kparams.vtcm_size, sess->vtcm_size);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -2211,14 +2385,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
|
||||
kparams->kernel_type = (src1_nrows < (int) sess->n_threads) ? HTP_MM_KERNEL_HVX_QUANT_BLOCK : HTP_MM_KERNEL_HVX_QUANT_ROW;
|
||||
kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
|
||||
|
||||
size_t vtcm_src0_size = 0, vtcm_src1_size = 0;
|
||||
size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0;
|
||||
uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16;
|
||||
uint32_t best_n_prefetch = 2;
|
||||
size_t total_size = 0;
|
||||
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
|
||||
total_size = htp_mm_hvx_id_get_vtcm_sizes(
|
||||
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], d,
|
||||
&vtcm_src0_size, &vtcm_src1_size
|
||||
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
|
||||
);
|
||||
if (total_size <= vtcm_budget) {
|
||||
best_n_prefetch = d;
|
||||
@@ -2228,14 +2402,14 @@ static void ggml_hexagon_precompute_hvx_mm_params(
|
||||
if (best_n_prefetch == 2 && total_size > vtcm_budget) {
|
||||
total_size = htp_mm_hvx_id_get_vtcm_sizes(
|
||||
wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], 2,
|
||||
&vtcm_src0_size, &vtcm_src1_size
|
||||
&vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size
|
||||
);
|
||||
}
|
||||
kparams->n_prefetch = best_n_prefetch;
|
||||
kparams->vtcm_size = total_size;
|
||||
kparams->vtcm_src0_size = vtcm_src0_size;
|
||||
kparams->vtcm_src1_size = vtcm_src1_size;
|
||||
kparams->vtcm_dst_size = 0;
|
||||
kparams->vtcm_dst_size = vtcm_dst_size;
|
||||
} else {
|
||||
bool try_tiled = (k_align && opt_mm_select >= 2);
|
||||
if (try_tiled) {
|
||||
@@ -2441,11 +2615,12 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
size_t src3_sz_per_thread = 0;
|
||||
uint32_t best_n_prefetch = 16;
|
||||
|
||||
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
|
||||
size_t src1_sz = src1_sz_per_thread;
|
||||
|
||||
@@ -2453,13 +2628,10 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
best_n_prefetch = 2;
|
||||
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
|
||||
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
|
||||
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
|
||||
size_t src3_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
|
||||
|
||||
if (tiled_vtcm_size <= sess->vtcm_size) {
|
||||
best_n_prefetch = d;
|
||||
@@ -2471,9 +2643,6 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
}
|
||||
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
|
||||
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
src0_sz_per_thread = repacked_vtcm_size;
|
||||
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
|
||||
src3_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
|
||||
@@ -2492,7 +2661,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
|
||||
size_t src3_sz = src3_sz_per_thread * sess->n_threads;
|
||||
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz + quant_scratch_size;
|
||||
bool try_tiled = (opt_mm_select >= 2);
|
||||
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
|
||||
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
|
||||
@@ -2500,6 +2669,7 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
kparams->vtcm_src1_size = src1_sz;
|
||||
kparams->vtcm_src2_size = src2_sz;
|
||||
kparams->vtcm_src3_size = src3_sz;
|
||||
kparams->vtcm_dst_size = quant_scratch_size;
|
||||
kparams->vtcm_size = tiled_vtcm_size;
|
||||
kparams->n_prefetch = best_n_prefetch;
|
||||
} else {
|
||||
@@ -2510,7 +2680,8 @@ static void ggml_hexagon_precompute_fused_qkv_params(
|
||||
kparams->vtcm_src1_size = flat_src1_sz;
|
||||
kparams->vtcm_src2_size = src2_sz;
|
||||
kparams->vtcm_src3_size = src3_sz;
|
||||
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz;
|
||||
kparams->vtcm_dst_size = quant_scratch_size;
|
||||
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz + quant_scratch_size;
|
||||
kparams->n_prefetch = best_n_prefetch;
|
||||
}
|
||||
}
|
||||
@@ -2536,11 +2707,12 @@ static void ggml_hexagon_precompute_fused_ffn_params(
|
||||
size_t src2_sz_per_thread = 0;
|
||||
uint32_t best_n_prefetch = 16;
|
||||
|
||||
size_t quant_scratch_size = hex_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * sess->n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128);
|
||||
size_t src1_sz = src1_sz_per_thread;
|
||||
|
||||
@@ -2548,12 +2720,9 @@ static void ggml_hexagon_precompute_fused_ffn_params(
|
||||
best_n_prefetch = 2;
|
||||
for (uint32_t d = max_prefetch; d >= 2; d /= 2) {
|
||||
size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
size_t src0_sz = repacked_vtcm_size * sess->n_threads;
|
||||
size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
|
||||
|
||||
if (tiled_vtcm_size <= sess->vtcm_size) {
|
||||
best_n_prefetch = d;
|
||||
@@ -2564,9 +2733,6 @@ static void ggml_hexagon_precompute_fused_ffn_params(
|
||||
}
|
||||
if (best_n_prefetch == 2 && src0_sz_per_thread == 0) {
|
||||
size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
src0_sz_per_thread = repacked_vtcm_size;
|
||||
src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128);
|
||||
}
|
||||
@@ -2582,13 +2748,14 @@ static void ggml_hexagon_precompute_fused_ffn_params(
|
||||
size_t src1_sz = src1_sz_per_thread;
|
||||
size_t src2_sz = src2_sz_per_thread * sess->n_threads;
|
||||
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz;
|
||||
size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + quant_scratch_size;
|
||||
bool try_tiled = (opt_mm_select >= 2);
|
||||
if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) {
|
||||
kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW;
|
||||
kparams->vtcm_src0_size = src0_sz;
|
||||
kparams->vtcm_src1_size = src1_sz;
|
||||
kparams->vtcm_src2_size = src2_sz;
|
||||
kparams->vtcm_dst_size = quant_scratch_size;
|
||||
kparams->vtcm_size = tiled_vtcm_size;
|
||||
kparams->n_prefetch = best_n_prefetch;
|
||||
} else {
|
||||
@@ -2598,7 +2765,8 @@ static void ggml_hexagon_precompute_fused_ffn_params(
|
||||
kparams->vtcm_src0_size = src0_sz;
|
||||
kparams->vtcm_src1_size = flat_src1_sz;
|
||||
kparams->vtcm_src2_size = src2_sz;
|
||||
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz;
|
||||
kparams->vtcm_dst_size = quant_scratch_size;
|
||||
kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + quant_scratch_size;
|
||||
kparams->n_prefetch = best_n_prefetch;
|
||||
}
|
||||
}
|
||||
@@ -3243,7 +3411,7 @@ static inline bool op_is_compute(ggml_tensor *node)
|
||||
return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
|
||||
}
|
||||
|
||||
static bool is_hmx_eligible(const ggml_tensor * t) {
|
||||
static bool mm_is_hmx_eligible(const ggml_tensor * t) {
|
||||
if (opt_nhmx == 0) { return false; }
|
||||
|
||||
const ggml_tensor * src0 = t->src[0];
|
||||
@@ -3262,7 +3430,7 @@ static bool is_hmx_eligible(const ggml_tensor * t) {
|
||||
static bool is_mergeable_mul_mat(const ggml_tensor * t) {
|
||||
if (!t || t->op != GGML_OP_MUL_MAT) return false;
|
||||
if (t->src[1]->type != GGML_TYPE_F32) return false;
|
||||
return ggml_is_quantized(t->src[0]->type) && !is_hmx_eligible(t);
|
||||
return ggml_is_quantized(t->src[0]->type) && !mm_is_hmx_eligible(t);
|
||||
}
|
||||
|
||||
static bool is_mergeable_mul_mat_pair(const ggml_tensor * n1, const ggml_tensor * n2) {
|
||||
@@ -3357,6 +3525,26 @@ static bool try_fuse_node(const ggml_hexagon_session * sess, const ggml_cgraph *
|
||||
}
|
||||
}
|
||||
|
||||
if (n->op == GGML_OP_MUL_MAT && next_node) {
|
||||
if (next_node->op == GGML_OP_ADD && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
|
||||
if (next_node->src[0] == n || next_node->src[1] == n) {
|
||||
struct htp_mm_kernel_params kparams;
|
||||
ggml_hexagon_precompute_matmul_params(sess, n->src[0], n->src[1], next_node, &kparams);
|
||||
if ((size_t)kparams.vtcm_size <= sess->vtcm_size) {
|
||||
htp_opnode node(n, {}, HTP_OP_MUL_MAT_ADD);
|
||||
node.add_fused(next_node);
|
||||
memcpy(node.kernel_params, &kparams, sizeof(kparams));
|
||||
nodes.push_back(std::move(node));
|
||||
i += 1;
|
||||
return true;
|
||||
} else {
|
||||
HEX_VERBOSE("ggml-hex: skip MUL_MAT_ADD fusion because VTCM needed (%d) > budget (%zu)\n",
|
||||
kparams.vtcm_size, sess->vtcm_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -3393,6 +3581,11 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
|
||||
node.node->src[0], node.node->src[1], node.node,
|
||||
(struct htp_mm_kernel_params *)node.kernel_params
|
||||
);
|
||||
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
|
||||
ggml_hexagon_precompute_flash_attn_params(sess,
|
||||
node.node,
|
||||
(struct htp_fa_kernel_params *)node.kernel_params
|
||||
);
|
||||
}
|
||||
computed_nodes.push_back(std::move(node));
|
||||
}
|
||||
@@ -4079,6 +4272,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX");
|
||||
const char * str_nhmx = getenv("GGML_HEXAGON_NHMX");
|
||||
const char * str_mm_select = getenv("GGML_HEXAGON_MM_SELECT");
|
||||
const char * str_fa_select = getenv("GGML_HEXAGON_FA_SELECT");
|
||||
const char * str_ndev = getenv("GGML_HEXAGON_NDEV");
|
||||
const char * str_arch = getenv("GGML_HEXAGON_ARCH");
|
||||
const char * str_vmem = getenv("GGML_HEXAGON_VMEM");
|
||||
@@ -4120,6 +4314,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
|
||||
opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx;
|
||||
opt_nhmx = str_nhmx ? atoi(str_nhmx) : (str_use_hmx ? atoi(str_use_hmx) : opt_nhmx);
|
||||
opt_mm_select = str_mm_select ? atoi(str_mm_select) : opt_mm_select;
|
||||
opt_fa_select = str_fa_select ? atoi(str_fa_select) : opt_fa_select;
|
||||
opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev;
|
||||
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf;
|
||||
opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf;
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <stdio.h>
|
||||
#include "htp-ops.h"
|
||||
#include "htp/matmul-ops.h"
|
||||
#include "htp/flash-attn-ops.h"
|
||||
|
||||
struct htp_opnode {
|
||||
ggml_tensor * node = nullptr;
|
||||
@@ -335,7 +336,8 @@ struct htp_opformat {
|
||||
}
|
||||
void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) {
|
||||
if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID ||
|
||||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) {
|
||||
node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN ||
|
||||
node.opcode == HTP_OP_MUL_MAT_ADD) {
|
||||
const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params;
|
||||
const char * path = "unknown";
|
||||
int32_t type = kparams->kernel_type;
|
||||
@@ -350,6 +352,16 @@ struct htp_opformat {
|
||||
path = "hvx-flat";
|
||||
}
|
||||
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
|
||||
} else if (node.opcode == HTP_OP_FLASH_ATTN_EXT) {
|
||||
const auto * kparams = (const struct htp_fa_kernel_params *) node.kernel_params;
|
||||
const char * path = "unknown";
|
||||
int32_t type = kparams->kernel_type;
|
||||
if (type == HTP_FA_KERNEL_HMX) {
|
||||
path = kparams->u.hmx.pipeline ? "hmx-pipe" : "hmx-seq";
|
||||
} else if (type == HTP_FA_KERNEL_HVX) {
|
||||
path = "hvx";
|
||||
}
|
||||
snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size);
|
||||
} else {
|
||||
snprintf(str, max_size, "----");
|
||||
}
|
||||
|
||||
@@ -20,9 +20,6 @@ add_library(${HTP_LIB} SHARED
|
||||
worker-pool.c
|
||||
hex-dma.c
|
||||
hmx-queue.c
|
||||
flash-attn-ops.c
|
||||
hmx-flash-attn-ops.c
|
||||
matmul-ops.c
|
||||
binary-ops.c
|
||||
unary-ops.c
|
||||
sum-rows-ops.c
|
||||
@@ -42,16 +39,14 @@ add_library(${HTP_LIB} SHARED
|
||||
solve-tri-ops.c
|
||||
gated-delta-net-ops.c
|
||||
pad-ops.c
|
||||
matmul-ops.c
|
||||
flash-attn-ops.c
|
||||
)
|
||||
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1>
|
||||
$<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,>)
|
||||
|
||||
if (GGML_HEXAGON_FA_EXP2_HF)
|
||||
message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)")
|
||||
target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1)
|
||||
endif()
|
||||
|
||||
build_idl(htp_iface.idl ${HTP_LIB})
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,253 @@
|
||||
#ifndef HTP_FLASH_ATTN_OPS_H
|
||||
#define HTP_FLASH_ATTN_OPS_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "hex-fastdiv.h"
|
||||
#include "hex-common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Tile constants (mirrored from hmx-utils.h for use on host side if needed)
|
||||
#define HMX_FP16_TILE_N_ROWS 32
|
||||
#define HMX_FP16_TILE_N_COLS 32
|
||||
#define HMX_FP16_TILE_N_ELMS 1024
|
||||
#define HMX_FP16_TILE_SIZE 2048
|
||||
#define HVX_FA_DMA_CACHE_SIZE 128
|
||||
#define HMX_FA_DMA_CACHE_SIZE 4
|
||||
|
||||
#define HTP_FA_M_INITIAL_VAL -10000.0f
|
||||
|
||||
enum htp_fa_kernel_type {
|
||||
HTP_FA_KERNEL_UNSUPPORTED = 0,
|
||||
HTP_FA_KERNEL_HVX,
|
||||
HTP_FA_KERNEL_HMX
|
||||
};
|
||||
|
||||
struct htp_fa_kernel_params {
|
||||
uint8_t kernel_type; // enum htp_fa_kernel_type
|
||||
uint8_t is_q_fp32; // 1 = Q type is F32, 0 = F16
|
||||
uint8_t is_dst_fp32; // 1 = dst type is F32, 0 = F16
|
||||
uint8_t n_threads; // Number of threads to run
|
||||
|
||||
// Common parameters
|
||||
uint16_t Br;
|
||||
uint16_t Bc;
|
||||
uint16_t n_kv_blocks; // also HVX's n_blocks
|
||||
uint16_t G; // GQA factor (n_heads / n_kv_heads)
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
uint32_t vtcm_size;
|
||||
|
||||
uint32_t qrows;
|
||||
uint32_t qrows_per_thread;
|
||||
float m0;
|
||||
float m1;
|
||||
uint32_t n_head_log2;
|
||||
|
||||
struct fastdiv_values src3_div2;
|
||||
struct fastdiv_values src3_div3;
|
||||
|
||||
union {
|
||||
struct {
|
||||
uint32_t g_br;
|
||||
uint32_t row_buf_stride;
|
||||
uint32_t mask_buf_row_stride;
|
||||
int32_t mask_broadcast;
|
||||
int32_t pipeline;
|
||||
struct fastdiv_values div_G;
|
||||
} hmx;
|
||||
struct {
|
||||
uint32_t size_q_row_padded;
|
||||
uint32_t size_k_row_padded;
|
||||
uint32_t size_v_row_padded;
|
||||
struct fastdiv_values src0_div21;
|
||||
struct fastdiv_values src0_div1;
|
||||
struct fastdiv_values broadcast_rk2;
|
||||
struct fastdiv_values broadcast_rk3;
|
||||
struct fastdiv_values broadcast_rv2;
|
||||
struct fastdiv_values broadcast_rv3;
|
||||
} hvx;
|
||||
} u;
|
||||
};
|
||||
|
||||
#if defined(__cplusplus)
|
||||
static_assert(sizeof(struct htp_fa_kernel_params) <= 128, "htp_fa_kernel_params is too large for kernel_params blob");
|
||||
#endif
|
||||
|
||||
// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration.
|
||||
// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions.
|
||||
// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales
|
||||
// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax.
|
||||
static inline size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) {
|
||||
const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS);
|
||||
const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK]
|
||||
const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong
|
||||
const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf
|
||||
const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf
|
||||
const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved
|
||||
const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved
|
||||
const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc]
|
||||
const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br]
|
||||
const size_t col_vec_size = hex_align_up(g_br * sizeof(float), 256); // m, l, etc.
|
||||
const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256);
|
||||
const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128);
|
||||
const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096) * HMX_FA_DMA_CACHE_SIZE;
|
||||
const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128);
|
||||
|
||||
return q_tile_size * 1 // Q tiles
|
||||
+ o_tile_size * 2 // O ping-pong
|
||||
+ k_dma_size * 2 // K DMA x2
|
||||
+ v_dma_size * 2 // V DMA x2
|
||||
+ k_tile_size * 1 // K tiles
|
||||
+ v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining)
|
||||
+ s_tile_size * 2 // S + P
|
||||
+ d_tile_size * 1 // D (diagonal matrix)
|
||||
+ col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum
|
||||
+ row_vec_size * 2 * n_threads // per-thread softmax row scratch
|
||||
+ m_buf_size * 1 // mask VTCM buffer [Br rows]
|
||||
+ slopes_size // Slopes
|
||||
+ 256 * 2; // HMX scales (id + qk)
|
||||
}
|
||||
|
||||
#define FA_HVX_BLOCK_SIZE 64
|
||||
|
||||
static inline size_t hvx_fa_compute_vtcm_usage(size_t DK, size_t DV, bool is_q_fp32, bool has_mask, size_t n_threads) {
|
||||
const size_t size_q_row_padded = hex_round_up(DK * (is_q_fp32 ? 4 : 2), 128);
|
||||
const size_t size_k_row_padded = hex_round_up(DK * sizeof(__fp16), 128);
|
||||
const size_t size_v_row_padded = hex_round_up(DV * sizeof(__fp16), 128);
|
||||
|
||||
const size_t size_q_block = size_q_row_padded * 1;
|
||||
const size_t size_k_block = size_k_row_padded * FA_HVX_BLOCK_SIZE;
|
||||
const size_t size_v_block = size_v_row_padded * FA_HVX_BLOCK_SIZE;
|
||||
const size_t size_m_block = hex_round_up(FA_HVX_BLOCK_SIZE * sizeof(__fp16), 128);
|
||||
const size_t size_vkq_acc = hex_round_up(DV * sizeof(float), 128);
|
||||
|
||||
const size_t size_per_thread = size_q_block * 1
|
||||
+ size_k_block * 2
|
||||
+ size_v_block * 2
|
||||
+ (has_mask ? size_m_block * HVX_FA_DMA_CACHE_SIZE : 0)
|
||||
+ size_vkq_acc;
|
||||
|
||||
return size_per_thread * n_threads;
|
||||
}
|
||||
|
||||
#define FA_MIN_KV_BLOCKS 3
|
||||
|
||||
// Cost-based (Br, Bc) search for flash attention with pipeline constraint.
|
||||
static inline int hmx_fa_find_chunk_size(size_t * Br_out,
|
||||
size_t * Bc_out,
|
||||
size_t gqa_factor,
|
||||
size_t DK,
|
||||
size_t DV,
|
||||
size_t qo_len,
|
||||
size_t kv_len,
|
||||
size_t vtcm_budget,
|
||||
size_t n_threads) {
|
||||
const size_t T = HMX_FP16_TILE_N_ROWS; // 32
|
||||
const size_t br_unit = hmx_ceil_div(T, gqa_factor);
|
||||
const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64
|
||||
const size_t fp16 = sizeof(__fp16);
|
||||
const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2);
|
||||
|
||||
// Approximate per-unit VTCM costs (without per-buffer alignment padding).
|
||||
const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * sizeof(float); // Q + O*2 + 4 col vectors
|
||||
const size_t per_gbr2 = fp16; // D diagonal matrix
|
||||
const size_t per_bc =
|
||||
3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs
|
||||
const size_t per_gbr_bc = 2 * fp16; // S + P
|
||||
|
||||
const size_t overhead = 256 * 2 + 13 * 4096;
|
||||
|
||||
if (vtcm_budget <= overhead) {
|
||||
return -1;
|
||||
}
|
||||
const size_t usable = vtcm_budget - overhead;
|
||||
|
||||
// Br_max: largest Br aligned to br_unit that does not exceed qo_len.
|
||||
const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit;
|
||||
|
||||
// Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS.
|
||||
// Only relax when kv_len is too short to form enough blocks.
|
||||
const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) :
|
||||
(kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit);
|
||||
// Cost coefficients calibrated from profiling
|
||||
const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store
|
||||
const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers
|
||||
|
||||
size_t best_cost = SIZE_MAX, best_mn = 0;
|
||||
size_t best_Br = 0, best_Bc = 0;
|
||||
|
||||
for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) {
|
||||
const size_t g_br = hex_align_up(gqa_factor * Br, T);
|
||||
|
||||
// g_br-dependent VTCM cost: g_br * per_gbr + g_br*g_br * per_gbr2
|
||||
const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2;
|
||||
if (gbr_cost >= usable) {
|
||||
if (Br == br_unit) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Analytically solve for max Bc:
|
||||
// remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE)
|
||||
// The Br * fp16 term accounts for the VTCM mask buffer [Br * Bc].
|
||||
const size_t remain = usable - gbr_cost;
|
||||
const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16 * HMX_FA_DMA_CACHE_SIZE;
|
||||
size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit);
|
||||
if (Bc < bc_unit) {
|
||||
if (Br == br_unit) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Exact VTCM verification (alignment padding may push over budget)
|
||||
while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) {
|
||||
Bc -= bc_unit;
|
||||
}
|
||||
if (Bc < bc_unit) {
|
||||
if (Br == br_unit) {
|
||||
break;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const size_t q_blocks = (qo_len + Br - 1) / Br;
|
||||
const size_t kv_blocks = (kv_len + Bc - 1) / Bc;
|
||||
const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed);
|
||||
const size_t mn = Br * Bc;
|
||||
|
||||
if (cost < best_cost || (cost == best_cost && mn > best_mn)) {
|
||||
best_cost = cost;
|
||||
best_mn = mn;
|
||||
best_Br = Br;
|
||||
best_Bc = Bc;
|
||||
}
|
||||
|
||||
if (Br == br_unit) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (best_Br == 0) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
*Br_out = best_Br;
|
||||
*Bc_out = best_Bc;
|
||||
return 0;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* HTP_FLASH_ATTN_OPS_H */
|
||||
@@ -138,27 +138,28 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
|
||||
}
|
||||
|
||||
dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx];
|
||||
desc->next = NULL;
|
||||
desc->desc_size = 0; // 1D mode
|
||||
desc->src_bypass = dma_src_l2_bypass_on;
|
||||
desc->dst_bypass = dma_dst_l2_bypass_on;
|
||||
desc->order = 0;
|
||||
desc->done = 0;
|
||||
desc->src = (void *) dptr.src;
|
||||
desc->dst = (void *) dptr.dst;
|
||||
desc->size = size;
|
||||
desc->src = (void *) dptr.src;
|
||||
desc->dst = (void *) dptr.dst;
|
||||
desc->size = size;
|
||||
|
||||
q->dptr[q->push_idx] = dptr;
|
||||
|
||||
if (size) {
|
||||
desc->next = NULL;
|
||||
desc->desc_size = 0; // 1D mode
|
||||
desc->src_bypass = dma_src_l2_bypass_on;
|
||||
desc->dst_bypass = dma_dst_l2_bypass_on;
|
||||
desc->order = 0;
|
||||
desc->done = 0;
|
||||
|
||||
htp_trace_event_start(q->trace, HTP_TRACE_EVT_DMA, q->push_idx);
|
||||
dmlink(q->tail, desc);
|
||||
q->tail = (dma_descriptor_2d *) desc;
|
||||
} else {
|
||||
desc->done = 1;
|
||||
desc->desc_size = 0;
|
||||
desc->done = 1;
|
||||
}
|
||||
|
||||
// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
|
||||
q->push_idx = (q->push_idx + 1) & q->idx_mask;
|
||||
return true;
|
||||
}
|
||||
@@ -320,7 +321,7 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_
|
||||
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
|
||||
}
|
||||
|
||||
#define DMA_CACHE_MAX_SIZE 64U
|
||||
#define DMA_CACHE_MAX_SIZE 256U
|
||||
|
||||
typedef struct {
|
||||
uint8_t *base;
|
||||
@@ -352,20 +353,19 @@ static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * sr
|
||||
if (c->src[i] == (uint32_t) src) {
|
||||
c->age[i] = 0;
|
||||
dst = c->base + (i * c->line_size); nrows = 0; // dummy dma
|
||||
// FARF(ERROR, "dma-cache: found %p", src);
|
||||
} else {
|
||||
c->age[i]++;
|
||||
if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; }
|
||||
}
|
||||
}
|
||||
if (!dst) {
|
||||
// FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src);
|
||||
c->age[o_idx] = 0;
|
||||
c->src[o_idx] = (uint32_t) src;
|
||||
dst = c->base + o_idx * c->line_size; // normal nrows dma
|
||||
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
|
||||
}
|
||||
|
||||
return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
|
||||
return dma_queue_push_single_1d(q, dma_make_ptr(dst, src), 0);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
#ifndef HMX_FA_KERNELS_H
|
||||
#define HMX_FA_KERNELS_H
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stddef.h>
|
||||
#include <stdbool.h>
|
||||
#include "hvx-utils.h"
|
||||
#include "hmx-utils.h"
|
||||
|
||||
// HMX-specific parameters, offsets and inner kernels for Flash Attention
|
||||
|
||||
// Scatter offsets for diagonal tile: entry[2i] = i*136, entry[2i+1] = i*136+6
|
||||
// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile
|
||||
static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = {
|
||||
0 * 136, 0 * 136 + 6,
|
||||
1 * 136, 1 * 136 + 6,
|
||||
2 * 136, 2 * 136 + 6,
|
||||
3 * 136, 3 * 136 + 6,
|
||||
4 * 136, 4 * 136 + 6,
|
||||
5 * 136, 5 * 136 + 6,
|
||||
6 * 136, 6 * 136 + 6,
|
||||
7 * 136, 7 * 136 + 6,
|
||||
8 * 136, 8 * 136 + 6,
|
||||
9 * 136, 9 * 136 + 6,
|
||||
10 * 136, 10 * 136 + 6,
|
||||
11 * 136, 11 * 136 + 6,
|
||||
12 * 136, 12 * 136 + 6,
|
||||
13 * 136, 13 * 136 + 6,
|
||||
14 * 136, 14 * 136 + 6,
|
||||
15 * 136, 15 * 136 + 6,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
0, 0,
|
||||
};
|
||||
// Inner HMX tile computation kernels
|
||||
|
||||
static inline void hmx_fa_qk_dot_tile(
|
||||
const __fp16 * row_tiles,
|
||||
const __fp16 * col_tiles,
|
||||
__fp16 * out_tile,
|
||||
size_t n_dot_tiles
|
||||
) {
|
||||
for (size_t k = 0; k < n_dot_tiles; ++k) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047);
|
||||
row_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
col_tiles += HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
Q6_mxmem_AR_after_hf(out_tile, 0);
|
||||
}
|
||||
|
||||
static inline void hmx_fa_o_update_tile(
|
||||
const __fp16 * d_diag,
|
||||
const __fp16 * o_rc,
|
||||
const __fp16 * p_tile_in,
|
||||
const __fp16 * v_tile_in,
|
||||
__fp16 * o_tile_out,
|
||||
size_t n_col_tiles
|
||||
) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
|
||||
|
||||
for (size_t k = 0; k < n_col_tiles; ++k) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047);
|
||||
p_tile_in += HMX_FP16_TILE_N_ELMS;
|
||||
v_tile_in += HMX_FP16_TILE_N_ELMS;
|
||||
}
|
||||
|
||||
Q6_mxmem_AR_after_hf(o_tile_out, 0);
|
||||
}
|
||||
|
||||
static inline void hmx_fa_o_norm_tile(
|
||||
const __fp16 * d_diag,
|
||||
const __fp16 * o_rc,
|
||||
__fp16 * o_out
|
||||
) {
|
||||
Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047);
|
||||
Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047);
|
||||
Q6_mxmem_AR_after_hf(o_out, 0);
|
||||
}
|
||||
|
||||
#endif /* HMX_FA_KERNELS_H */
|
||||
File diff suppressed because it is too large
Load Diff
@@ -712,7 +712,17 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job,
|
||||
|
||||
// output : fp16 -> f32p
|
||||
|
||||
static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, uint32_t start_row, uint32_t n_rows, uint32_t n_cols, uint32_t dst_stride, uint32_t dst_cols) {
|
||||
static void transfer_output_chunk_fp16_to_fp32(
|
||||
float *restrict dst,
|
||||
const float *restrict src2,
|
||||
const __fp16 *restrict vtcm_src,
|
||||
uint32_t start_row,
|
||||
uint32_t n_rows,
|
||||
uint32_t n_cols,
|
||||
uint32_t dst_stride,
|
||||
uint32_t src2_stride,
|
||||
uint32_t dst_cols
|
||||
) {
|
||||
assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0);
|
||||
const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS;
|
||||
|
||||
@@ -727,6 +737,7 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
|
||||
const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile
|
||||
const __fp16 *row_base = vtcm_src + r0 * tile_row_stride;
|
||||
float *output_row_base = dst + r * dst_stride; // global memory row base for row r (and r+1)
|
||||
const float *src2_row_base = src2 ? (src2 + r * src2_stride) : NULL;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (size_t c = 0; c < limit_c_aligned; c += HTP_MM_HMX_TILE_N_COLS) {
|
||||
@@ -738,9 +749,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
|
||||
HVX_Vector *pv_out0 = (HVX_Vector *) (output_row_base + c + 0);
|
||||
HVX_Vector *pv_out1 = (HVX_Vector *) (output_row_base + c + dst_stride);
|
||||
|
||||
*pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
|
||||
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
|
||||
if (src2_row_base) {
|
||||
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
|
||||
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
|
||||
}
|
||||
*pv_out0 = v_out0;
|
||||
|
||||
if (r + 1 < n_rows) {
|
||||
*pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
|
||||
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
|
||||
if (src2_row_base) {
|
||||
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
|
||||
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
|
||||
}
|
||||
*pv_out1 = v_out1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -752,9 +774,20 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
|
||||
HVX_Vector v = ((const HVX_Vector *) tile)[r1];
|
||||
HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one);
|
||||
|
||||
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)));
|
||||
HVX_Vector v_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp));
|
||||
if (src2_row_base) {
|
||||
HVX_Vector v_src2_0 = hvx_vmemu(src2_row_base + c + 0);
|
||||
v_out0 = hvx_vec_add_f32_f32(v_out0, v_src2_0);
|
||||
}
|
||||
hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), v_out0);
|
||||
|
||||
if (r + 1 < n_rows) {
|
||||
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)));
|
||||
HVX_Vector v_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp));
|
||||
if (src2_row_base) {
|
||||
HVX_Vector v_src2_1 = hvx_vmemu(src2_row_base + c + src2_stride);
|
||||
v_out1 = hvx_vec_add_f32_f32(v_out1, v_src2_1);
|
||||
}
|
||||
hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), v_out1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -763,11 +796,13 @@ static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16
|
||||
typedef struct {
|
||||
const __fp16 *vtcm_src;
|
||||
float *dst;
|
||||
const float *src2;
|
||||
uint32_t n_tasks;
|
||||
uint32_t n_tot_chunks;
|
||||
uint32_t n_chunks_per_task;
|
||||
uint32_t n_cols;
|
||||
uint32_t dst_stride; // DDR row stride
|
||||
uint32_t src2_stride; // DDR row stride for residual
|
||||
uint32_t dst_cols; // Actual output columns
|
||||
struct htp_thread_trace * traces;
|
||||
} output_transfer_task_state_t;
|
||||
|
||||
@@ -42,14 +42,14 @@ static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VL
|
||||
// Full range: start_row=0, end_row=n_cols.
|
||||
static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
const __fp16 * restrict vtcm_src,
|
||||
int n_cols,
|
||||
int k,
|
||||
int src_stride,
|
||||
int start_row,
|
||||
int end_row) {
|
||||
uint32_t n_cols,
|
||||
uint32_t k,
|
||||
size_t src_stride,
|
||||
uint32_t start_row,
|
||||
uint32_t end_row) {
|
||||
assert(k % HMX_FP16_TILE_N_COLS == 0);
|
||||
|
||||
const int n_k_tiles = k / HMX_FP16_TILE_N_COLS;
|
||||
const uint32_t n_k_tiles = k / HMX_FP16_TILE_N_COLS;
|
||||
const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets);
|
||||
const HVX_Vector v_scat_step = Q6_V_vsplat_R(4);
|
||||
const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64);
|
||||
@@ -65,14 +65,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
|
||||
if (pair_scatter) {
|
||||
// Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter.
|
||||
const int c_step = 2 * HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const int n_c_iters = k / c_step;
|
||||
const uint32_t c_step = 2 * HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const uint32_t n_c_iters = k / c_step;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const int ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
for (uint32_t r = start_row; r < end_row; r += 2) {
|
||||
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
|
||||
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
|
||||
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
|
||||
@@ -86,7 +86,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
assert(c_byte_step % 128 == 0);
|
||||
|
||||
if (p1) {
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
for (uint32_t i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
|
||||
HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step;
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
|
||||
@@ -95,7 +95,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
for (uint32_t i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step;
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0);
|
||||
Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero);
|
||||
@@ -105,14 +105,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
}
|
||||
} else {
|
||||
// Fallback: scatter one K-tile per call (region 2047, masked).
|
||||
const int c_step = HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const int n_c_iters = k / c_step;
|
||||
const uint32_t c_step = HMX_FP16_TILE_N_COLS;
|
||||
const size_t c_byte_step = (size_t) c_step * sizeof(__fp16);
|
||||
const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS;
|
||||
const uint32_t n_c_iters = k / c_step;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
const int ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
for (uint32_t r = start_row; r < end_row; r += 2) {
|
||||
const uint32_t ct = r / HMX_FP16_TILE_N_ROWS;
|
||||
const uint32_t local_r = r % HMX_FP16_TILE_N_ROWS;
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols;
|
||||
const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4));
|
||||
const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step);
|
||||
@@ -122,7 +122,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
if (p1) {
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
for (uint32_t i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
|
||||
HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step;
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
|
||||
@@ -131,7 +131,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int i = 0; i < n_c_iters; ++i) {
|
||||
for (uint32_t i = 0; i < n_c_iters; ++i) {
|
||||
HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step;
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0);
|
||||
Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero);
|
||||
@@ -148,24 +148,24 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst,
|
||||
// Full range: start_row=0, end_row=n_rows.
|
||||
static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
|
||||
const __fp16 * restrict src,
|
||||
int n_rows,
|
||||
int head_dim,
|
||||
int src_stride,
|
||||
int n_row_tiles,
|
||||
int start_row,
|
||||
int end_row) {
|
||||
uint32_t n_rows,
|
||||
uint32_t head_dim,
|
||||
size_t src_stride,
|
||||
uint32_t n_row_tiles,
|
||||
uint32_t start_row,
|
||||
uint32_t end_row) {
|
||||
__builtin_assume(head_dim > 0);
|
||||
const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS;
|
||||
|
||||
for (int r = start_row; r < end_row; r += 2) {
|
||||
for (uint32_t r = start_row; r < end_row; r += 2) {
|
||||
const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows;
|
||||
|
||||
const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride);
|
||||
const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL;
|
||||
|
||||
// Row-pair invariants hoisted out of the c loop.
|
||||
const int r0 = r / HMX_FP16_TILE_N_ROWS;
|
||||
const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
|
||||
const uint32_t r0 = r / HMX_FP16_TILE_N_ROWS;
|
||||
const uint32_t r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2;
|
||||
|
||||
// tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0).
|
||||
// Each c step (+= 64) advances both by 2 dim-tiles worth of fp16.
|
||||
@@ -174,7 +174,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
|
||||
const size_t tb_step = 2 * tile_stride_elms;
|
||||
|
||||
if (pv_in1) {
|
||||
for (int c = 0; c < head_dim; c += 64) {
|
||||
for (uint32_t c = 0; c < head_dim; c += 64) {
|
||||
HVX_Vector v0 = *pv_in0++;
|
||||
HVX_Vector v1 = *pv_in1++;
|
||||
HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2);
|
||||
@@ -185,7 +185,7 @@ static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out,
|
||||
}
|
||||
} else {
|
||||
const HVX_Vector vzero = Q6_V_vzero();
|
||||
for (int c = 0; c < head_dim; c += 64) {
|
||||
for (uint32_t c = 0; c < head_dim; c += 64) {
|
||||
HVX_Vector v0 = *pv_in0++;
|
||||
HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2);
|
||||
((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp);
|
||||
|
||||
@@ -60,6 +60,7 @@ enum htp_op_code {
|
||||
HTP_OP_MUL_MAT_ID,
|
||||
HTP_OP_MUL_MAT_QKV,
|
||||
HTP_OP_MUL_MAT_FFN,
|
||||
HTP_OP_MUL_MAT_ADD,
|
||||
HTP_OP_RMS_NORM,
|
||||
HTP_OP_RMS_NORM_MUL,
|
||||
HTP_OP_UNARY_SILU,
|
||||
@@ -175,6 +176,11 @@ enum htp_trace_event_id {
|
||||
HTP_TRACE_EVT_HVX_W_DEQUANT = 23,
|
||||
HTP_TRACE_EVT_HVX_W_PREP = 24,
|
||||
HTP_TRACE_EVT_HVX_O_PROC = 25,
|
||||
HTP_TRACE_EVT_HVX_FA_QK = 26,
|
||||
HTP_TRACE_EVT_HVX_FA_SFM = 27,
|
||||
HTP_TRACE_EVT_HVX_FA_Q_PREP = 28,
|
||||
HTP_TRACE_EVT_HVX_FA_K_PREP = 29,
|
||||
HTP_TRACE_EVT_HVX_FA_V_PREP = 30,
|
||||
|
||||
HTP_TRACE_EVT_HMX_COMP = 40,
|
||||
};
|
||||
|
||||
@@ -134,16 +134,7 @@ static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1)
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) {
|
||||
HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
// replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0)
|
||||
const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY);
|
||||
HVX_VectorPred nan = hvx_vec_is_nan_f16(v);
|
||||
v = Q6_V_vmux_QVV(nan, neg_inf, v);
|
||||
#endif
|
||||
|
||||
return v;
|
||||
return Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1));
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ >= 79
|
||||
@@ -170,8 +161,6 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) {
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) {
|
||||
// This looks complicated.
|
||||
// Ideally should just be Q6_Vh_equals_Vhf(vin)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!)
|
||||
#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805
|
||||
#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408
|
||||
#define EXP_LOG2E_F 1.44269504f
|
||||
#define EXP_ONE (0x3f800000) // 1.0
|
||||
#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228
|
||||
#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN))
|
||||
@@ -213,4 +214,42 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict
|
||||
}
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_exp2_f16(HVX_Vector x_v) {
|
||||
const HVX_Vector zero_v = Q6_V_vzero();
|
||||
const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800); // fp16 0.5
|
||||
|
||||
// Clamp input to prevent integer underflow in FP16-to-INT16 conversion
|
||||
const HVX_Vector v_clamp_min = hvx_vec_splat_f16(-24.0f);
|
||||
x_v = Q6_Vhf_vmax_VhfVhf(v_clamp_min, x_v);
|
||||
|
||||
// k = round_toward_neg_inf(x); f = (float)k; frac = x - f
|
||||
HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v));
|
||||
HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half); // truncate to int16
|
||||
HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v); // back to fp16
|
||||
|
||||
HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v); // fractional part in qf16
|
||||
|
||||
// Horner: y = ((((E5*x + E4)*x + E3)*x + E2)*x + E1)*x + E0
|
||||
HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); // E5*x
|
||||
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); // + E4
|
||||
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
|
||||
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); // + E3
|
||||
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
|
||||
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); // + E2
|
||||
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
|
||||
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); // + E1
|
||||
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16);
|
||||
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); // + E0
|
||||
y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); // y = y * x
|
||||
y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00)); // + 1.0
|
||||
|
||||
// Combine polynomial (mantissa) with integer part (exponent): result = y * 2^k
|
||||
y = Q6_Vhf_equals_Vqf16(y);
|
||||
HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11);
|
||||
y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp);
|
||||
HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp);
|
||||
y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10);
|
||||
return Q6_V_vmux_QVV(q_underflow, zero_v, y);
|
||||
}
|
||||
|
||||
#endif /* HVX_EXP_H */
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
#ifndef HVX_FA_KERNELS_H
|
||||
#define HVX_FA_KERNELS_H
|
||||
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include "hvx-utils.h"
|
||||
|
||||
// Little inner kernels for HVX
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
|
||||
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
|
||||
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
// This is a bit of a hack because the compiler is struggling to properly inline
|
||||
// the default hvx_vec_f32_to_f16 with output into the local array.
|
||||
static __attribute__((unused)) __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
|
||||
{
|
||||
*(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1);
|
||||
}
|
||||
|
||||
// Dot product of two F16 vectors, accumulating to float
|
||||
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
|
||||
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
|
||||
|
||||
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p));
|
||||
rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
|
||||
hvx_vec_store_u(r, 4, rsum);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
|
||||
const uint8_t * restrict x,
|
||||
const size_t stride_x,
|
||||
const size_t nvec,
|
||||
const size_t nloe) {
|
||||
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16
|
||||
const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16
|
||||
const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16
|
||||
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
|
||||
|
||||
HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
||||
HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
||||
HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
||||
HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0));
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector y_hf = vy[i];
|
||||
HVX_Vector x0_hf = vx0[i];
|
||||
HVX_Vector x1_hf = vx1[i];
|
||||
HVX_Vector x2_hf = vx2[i];
|
||||
HVX_Vector x3_hf = vx3[i];
|
||||
|
||||
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
|
||||
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
|
||||
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
|
||||
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load x (fp16) and zero-out unused elements
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]);
|
||||
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
|
||||
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
|
||||
HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]);
|
||||
HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]);
|
||||
|
||||
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
|
||||
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
|
||||
rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf);
|
||||
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p));
|
||||
HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p));
|
||||
HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p));
|
||||
HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p));
|
||||
|
||||
HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
|
||||
return hvx_vec_reduce_sum_f32x4(rsum0123);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
|
||||
const uint8_t * restrict x,
|
||||
const size_t stride_x,
|
||||
const size_t n,
|
||||
float s) {
|
||||
|
||||
const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
const size_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector sums = Q6_V_vzero();
|
||||
const size_t stride_x_4 = stride_x * 4;
|
||||
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
|
||||
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
|
||||
HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32);
|
||||
sums = Q6_V_vmux_QVV(pred, sums, sums_x4);
|
||||
x += stride_x_4;
|
||||
}
|
||||
|
||||
return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums);
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x (F16) * s (F16)
|
||||
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, uint32_t n) {
|
||||
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x;
|
||||
|
||||
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
|
||||
HVX_Vector * restrict vy = (HVX_Vector *) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector S0 = hvx_vec_splat_f16(*s);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPair xy_p = vy_p[i];
|
||||
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
||||
|
||||
HVX_Vector xy = Q6_V_lo_W(xy_p);
|
||||
i = 2 * i; // index for vy
|
||||
|
||||
if (nloe >= VLEN_FP32) {
|
||||
vy[i] = xy;
|
||||
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
hvx_vec_store_a(&vy[i], nloe * 4, xy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16)
|
||||
static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1,
|
||||
const __fp16 * restrict s0, const __fp16 * restrict s1, uint32_t n) {
|
||||
const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0;
|
||||
const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1;
|
||||
|
||||
HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y;
|
||||
HVX_Vector * restrict vy = (HVX_Vector *) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector S0 = hvx_vec_splat_f16(*s0);
|
||||
HVX_Vector S1 = hvx_vec_splat_f16(*s1);
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; ++i) {
|
||||
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
||||
vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPair xy_p = vy_p[i];
|
||||
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0);
|
||||
xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1);
|
||||
|
||||
HVX_Vector xy = Q6_V_lo_W(xy_p);
|
||||
i = 2 * i; // index for vy
|
||||
|
||||
if (nloe >= VLEN_FP32) {
|
||||
vy[i] = xy;
|
||||
nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
hvx_vec_store_a(&vy[i], nloe * 4, xy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t n, HVX_Vector vs) {
|
||||
assert((size_t) dst % 128 == 0);
|
||||
assert((size_t) src % 128 == 0);
|
||||
|
||||
const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src;
|
||||
HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst;
|
||||
|
||||
const uint32_t nvec = n / VLEN_FP32;
|
||||
const uint32_t nloe = n % VLEN_FP32;
|
||||
|
||||
uint32_t i = 0;
|
||||
#pragma unroll(4)
|
||||
for (; i < nvec; ++i) {
|
||||
vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs);
|
||||
}
|
||||
if (nloe) {
|
||||
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs));
|
||||
}
|
||||
}
|
||||
|
||||
#endif /* HVX_FA_KERNELS_H */
|
||||
@@ -256,7 +256,7 @@ static inline void quantize_f16_f16_flat_kernel(
|
||||
|
||||
// Dot kernels that consume flat (non-tiled) activations
|
||||
|
||||
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -312,10 +312,14 @@ static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const v
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -397,11 +401,19 @@ static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -464,10 +476,14 @@ static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const v
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -561,11 +577,19 @@ static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -620,10 +644,14 @@ static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const v
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -704,11 +732,19 @@ static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -765,10 +801,14 @@ static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -851,11 +891,19 @@ static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -921,10 +969,14 @@ static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
|
||||
|
||||
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -1019,6 +1071,441 @@ static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
|
||||
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#else
|
||||
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#endif
|
||||
|
||||
static inline void vec_dot_f32_f32_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
|
||||
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP32; // leftover elements
|
||||
|
||||
HVX_Vector rsum = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]);
|
||||
rsum = HVX_OP_ADD_F32(rsum, prod);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]);
|
||||
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
|
||||
HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf);
|
||||
rsum = HVX_OP_ADD_F32(rsum, prod);
|
||||
}
|
||||
|
||||
*s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum));
|
||||
}
|
||||
|
||||
static inline void vec_dot_f32_f32_aa_2x1(const uint32_t n, float * restrict s0,
|
||||
const void * restrict vx0, const void * restrict vx1,
|
||||
const void * restrict vy0) {
|
||||
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
||||
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
||||
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP32;
|
||||
uint32_t nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector rsum0 = Q6_V_vzero();
|
||||
HVX_Vector rsum1 = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector y_sf = y[i];
|
||||
HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf);
|
||||
HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf);
|
||||
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
|
||||
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]);
|
||||
HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]);
|
||||
HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]);
|
||||
HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf);
|
||||
HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf);
|
||||
rsum0 = HVX_OP_ADD_F32(rsum0, prod0);
|
||||
rsum1 = HVX_OP_ADD_F32(rsum1, prod1);
|
||||
}
|
||||
|
||||
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
|
||||
hvx_vec_store_u(s0, 8, rsum);
|
||||
}
|
||||
|
||||
static inline void vec_dot_f32_f32_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
|
||||
const void * restrict vx0, const void * restrict vx1,
|
||||
const void * restrict vy0, const void * restrict vy1) {
|
||||
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
||||
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
||||
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
|
||||
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP32;
|
||||
uint32_t nloe = n % VLEN_FP32;
|
||||
|
||||
HVX_Vector r0_c0_sum = Q6_V_vzero();
|
||||
HVX_Vector r0_c1_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_c0_sum = Q6_V_vzero();
|
||||
HVX_Vector r1_c1_sum = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector r0_sf = x0[i];
|
||||
HVX_Vector r1_sf = x1[i];
|
||||
HVX_Vector c0_sf = y0[i];
|
||||
HVX_Vector c1_sf = y1[i];
|
||||
|
||||
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
|
||||
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
|
||||
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
|
||||
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
|
||||
HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]);
|
||||
HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]);
|
||||
HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]);
|
||||
HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]);
|
||||
|
||||
r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf));
|
||||
r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf));
|
||||
r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf));
|
||||
r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf));
|
||||
}
|
||||
|
||||
// Reduce and store results
|
||||
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
||||
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
||||
|
||||
hvx_vec_store_u(s0, 8, r0_r1_c0_sum);
|
||||
hvx_vec_store_u(s1, 8, r0_r1_c1_sum);
|
||||
}
|
||||
|
||||
static inline void vec_dot_f32_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
|
||||
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
|
||||
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP32; // leftover elements
|
||||
|
||||
HVX_Vector rsum = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector x_sf = vx[i];
|
||||
HVX_Vector y_sf = vy[i];
|
||||
|
||||
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_Vector x_sf = vx[i];
|
||||
HVX_Vector y_sf = vy[i];
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
x_sf = Q6_V_vand_QV(bmask, x_sf);
|
||||
y_sf = Q6_V_vand_QV(bmask, y_sf);
|
||||
|
||||
rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf));
|
||||
}
|
||||
|
||||
rsum = hvx_vec_reduce_sum_f32(rsum);
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
#undef HVX_OP_ADD_F32
|
||||
#undef HVX_OP_MUL_F32
|
||||
|
||||
static inline void vec_dot_f16_f16_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
|
||||
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_VectorPair rsum_p = Q6_W_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
|
||||
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
|
||||
hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum));
|
||||
}
|
||||
|
||||
static inline void vec_dot_f16_f16_aa_2x1(const uint32_t n, float * restrict s0,
|
||||
const void * restrict vx0, const void * restrict vx1,
|
||||
const void * restrict vy0) {
|
||||
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
||||
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
||||
const HVX_Vector * restrict y = (const HVX_Vector *) vy0;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16;
|
||||
uint32_t nloe = n % VLEN_FP16;
|
||||
|
||||
HVX_VectorPair rsum0_p = Q6_W_vzero();
|
||||
HVX_VectorPair rsum1_p = Q6_W_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector y_hf = y[i];
|
||||
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf);
|
||||
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
|
||||
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]);
|
||||
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]);
|
||||
rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf);
|
||||
rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf);
|
||||
}
|
||||
|
||||
HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
|
||||
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
|
||||
HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1);
|
||||
hvx_vec_store_u(s0, 8, rsum);
|
||||
}
|
||||
|
||||
static inline void vec_dot_f16_f16_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1,
|
||||
const void * restrict vx0, const void * restrict vx1,
|
||||
const void * restrict vy0, const void * restrict vy1) {
|
||||
const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0;
|
||||
const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1;
|
||||
const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0;
|
||||
const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16;
|
||||
uint32_t nloe = n % VLEN_FP16;
|
||||
|
||||
// Row sums (sf) - 4 accumulators for 2x2 tile
|
||||
HVX_VectorPair r0_c0_sum_p = Q6_W_vzero();
|
||||
HVX_VectorPair r0_c1_sum_p = Q6_W_vzero();
|
||||
HVX_VectorPair r1_c0_sum_p = Q6_W_vzero();
|
||||
HVX_VectorPair r1_c1_sum_p = Q6_W_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_Vector r0_hf = x0[i];
|
||||
HVX_Vector r1_hf = x1[i];
|
||||
HVX_Vector c0_hf = y0[i];
|
||||
HVX_Vector c1_hf = y1[i];
|
||||
|
||||
// Compute 4 dot products: r0xc0, r0xc1, r1xc0, r1xc1
|
||||
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
|
||||
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
|
||||
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
|
||||
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
|
||||
HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]);
|
||||
HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]);
|
||||
HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]);
|
||||
HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]);
|
||||
|
||||
r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf);
|
||||
r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf);
|
||||
r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf);
|
||||
r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf);
|
||||
}
|
||||
|
||||
HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p)));
|
||||
HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p)));
|
||||
HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p)));
|
||||
HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p)));
|
||||
|
||||
// Reduce and store results
|
||||
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
|
||||
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
|
||||
|
||||
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
|
||||
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
|
||||
}
|
||||
|
||||
static inline void vec_dot_f16_f16_uu_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||
const HVX_UVector * restrict x = (const HVX_UVector *) vx;
|
||||
const HVX_UVector * restrict y = (const HVX_UVector *) vy;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
HVX_Vector rsum = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(4)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]);
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]);
|
||||
HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
static inline void vec_dot_f16_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) {
|
||||
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
|
||||
const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y;
|
||||
|
||||
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||
uint32_t nloe = n % VLEN_FP16; // leftover elements
|
||||
|
||||
const HVX_Vector zero = Q6_V_vzero();
|
||||
|
||||
HVX_Vector rsum = Q6_V_vzero();
|
||||
|
||||
uint32_t i = 0;
|
||||
|
||||
#pragma unroll(2)
|
||||
for (i = 0; i < nvec; i++) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
|
||||
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
if (nloe) {
|
||||
// Load y (fp32) and convert into fp16
|
||||
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
|
||||
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
|
||||
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
|
||||
|
||||
// Load x (fp16)
|
||||
HVX_Vector x_hf = vx[i];
|
||||
|
||||
// Zero-out unused elements
|
||||
// Note that we need to clear both x and y because they may contain NANs
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
|
||||
x_hf = Q6_V_vand_QV(bmask, x_hf);
|
||||
y_hf = Q6_V_vand_QV(bmask, y_hf);
|
||||
|
||||
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
|
||||
|
||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
|
||||
}
|
||||
|
||||
// Convert into fp32 and reduce
|
||||
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
|
||||
hvx_vec_store_u(&s[0], 4, rsum);
|
||||
}
|
||||
|
||||
static inline void hvx_tensor_add_f32_grid(
|
||||
const struct htp_tensor * restrict dst,
|
||||
const struct htp_tensor * restrict src2,
|
||||
uint32_t start_row,
|
||||
uint32_t end_row,
|
||||
uint32_t start_col,
|
||||
uint32_t end_col,
|
||||
const struct fastdiv_values * div_ne11_12,
|
||||
const struct fastdiv_values * div_ne11
|
||||
) {
|
||||
if (start_row >= end_row || start_col >= end_col) return;
|
||||
const uint32_t nb1 = dst->nb[1]; // row stride in bytes
|
||||
|
||||
const uint32_t ne11 = dst->ne[1];
|
||||
const uint32_t ne12 = dst->ne[2];
|
||||
const uint32_t ne11_12 = ne11 * ne12;
|
||||
|
||||
const bool is_broadcast1 = (src2->ne[1] == 1);
|
||||
const bool is_broadcast2 = (src2->ne[2] == 1);
|
||||
const bool is_broadcast3 = (src2->ne[3] == 1);
|
||||
|
||||
for (uint32_t r = start_row; r < end_row; r++) {
|
||||
float * dst_row = (float *) ((uint8_t *) dst->data + r * nb1);
|
||||
|
||||
uint32_t i13 = fastdiv(r, div_ne11_12);
|
||||
uint32_t i12 = fastdiv(r - i13 * ne11_12, div_ne11);
|
||||
uint32_t i11 = r - i13 * ne11_12 - i12 * ne11;
|
||||
|
||||
uint32_t i23 = is_broadcast3 ? 0 : i13;
|
||||
uint32_t i22 = is_broadcast2 ? 0 : i12;
|
||||
uint32_t i21 = is_broadcast1 ? 0 : i11;
|
||||
|
||||
const float * src2_row = (const float *) ((const uint8_t *) src2->data +
|
||||
i21 * src2->nb[1] + i22 * src2->nb[2] + i23 * src2->nb[3]);
|
||||
|
||||
float * dst_ptr = &dst_row[start_col];
|
||||
const float * src2_ptr = &src2_row[start_col];
|
||||
int remaining = end_col - start_col;
|
||||
while (remaining >= 32) {
|
||||
HVX_Vector v_out = hvx_vmemu(dst_ptr);
|
||||
HVX_Vector v_z = hvx_vmemu(src2_ptr);
|
||||
hvx_vmemu(dst_ptr) = hvx_vec_add_f32_f32(v_out, v_z);
|
||||
dst_ptr += 32;
|
||||
src2_ptr += 32;
|
||||
remaining -= 32;
|
||||
}
|
||||
if (remaining > 0) {
|
||||
HVX_Vector v_out = hvx_vmemu(dst_ptr);
|
||||
HVX_Vector v_z = hvx_vmemu(src2_ptr);
|
||||
hvx_vec_store_u(dst_ptr, remaining * sizeof(float), hvx_vec_add_f32_f32(v_out, v_z));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -378,7 +378,7 @@ static inline HVX_VectorPair accum_q8_0_32x2(
|
||||
return Q6_W_vcombine_VV(v_sum1, v_sum0);
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -401,10 +401,14 @@ static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -484,11 +488,19 @@ static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -519,10 +531,14 @@ static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -637,11 +653,19 @@ static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -663,10 +687,14 @@ static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -745,11 +773,19 @@ static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -773,10 +809,14 @@ static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const
|
||||
v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -857,11 +897,19 @@ static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, floa
|
||||
v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1);
|
||||
}
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows, const float * restrict sz) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y_q = vy;
|
||||
|
||||
@@ -896,10 +944,14 @@ static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const
|
||||
|
||||
v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f));
|
||||
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
if (sz) {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float, hvx_vmemu(sz)));
|
||||
} else {
|
||||
hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float);
|
||||
}
|
||||
}
|
||||
|
||||
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) {
|
||||
static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows, const float * restrict sz0, const float * restrict sz1) {
|
||||
const uint8_t * restrict tile_ptr = vx;
|
||||
const uint8_t * restrict y0_q = vy0;
|
||||
const uint8_t * restrict y1_q = vy1;
|
||||
@@ -1013,8 +1065,16 @@ static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, floa
|
||||
v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f));
|
||||
v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f));
|
||||
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
if (sz0) {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vmemu(sz0)));
|
||||
} else {
|
||||
hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0);
|
||||
}
|
||||
if (sz1) {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vmemu(sz1)));
|
||||
} else {
|
||||
hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1);
|
||||
}
|
||||
}
|
||||
|
||||
static inline void quantize_f32_q8_0_tiled_kernel(
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "hvx-base.h"
|
||||
#include "hvx-inverse.h"
|
||||
#include "hvx-exp.h"
|
||||
|
||||
#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022
|
||||
#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777
|
||||
@@ -139,4 +140,42 @@ static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restr
|
||||
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_fast_sigmoid_f16(HVX_Vector x_v) {
|
||||
const HVX_Vector v_one = hvx_vec_splat_f16(1.0f);
|
||||
const HVX_Vector v_neg_log2e = hvx_vec_splat_f16(-EXP_LOG2E_F);
|
||||
const HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF);
|
||||
|
||||
// Compute absolute value of x_v
|
||||
HVX_Vector abs_x = Q6_V_vand_VV(x_v, em_mask);
|
||||
|
||||
// Compute u = -abs_x * log2(e) <= 0.
|
||||
HVX_Vector u = hvx_vec_mul_f16_f16(abs_x, v_neg_log2e);
|
||||
|
||||
// Clamp input to prevent underflow in exp2
|
||||
const HVX_Vector v_clamp_min = hvx_vec_splat_f16(-24.0f);
|
||||
u = Q6_Vhf_vmax_VhfVhf(v_clamp_min, u);
|
||||
|
||||
HVX_Vector exp_val = hvx_vec_exp2_f16(u);
|
||||
HVX_Vector denom = hvx_vec_add_f16_f16(v_one, exp_val);
|
||||
HVX_Vector sig_abs = hvx_vec_inverse_f16(denom);
|
||||
|
||||
// check if x_v < 0 (using integer comparison on absolute value)
|
||||
HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(abs_x, x_v);
|
||||
|
||||
// If x_v < 0, return 1.0f - sig_abs
|
||||
HVX_Vector sig_neg = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(v_one, sig_abs));
|
||||
return Q6_V_vmux_QVV(is_neg, sig_neg, sig_abs);
|
||||
}
|
||||
|
||||
static inline HVX_Vector hvx_vec_tanh_f16(HVX_Vector x) {
|
||||
// tanh(x) = 2 * sigmoid(2x) - 1
|
||||
const HVX_Vector v_two = hvx_vec_splat_f16(2.0f);
|
||||
|
||||
HVX_Vector x2 = hvx_vec_mul_f16_f16(x, v_two);
|
||||
HVX_Vector sig2x = hvx_vec_fast_sigmoid_f16(x2);
|
||||
|
||||
const HVX_Vector v_neg_one = hvx_vec_splat_f16(-1.0f);
|
||||
return hvx_vec_add_f16_f16(hvx_vec_mul_f16_f16(sig2x, v_two), v_neg_one);
|
||||
}
|
||||
|
||||
#endif /* HVX_SIGMOID_H */
|
||||
|
||||
@@ -575,6 +575,7 @@ static inline void profile_stop(uint32_t mode, struct profile_data * d) {
|
||||
static int execute_op(struct htp_ops_context * octx) {
|
||||
switch (octx->op) {
|
||||
case HTP_OP_MUL_MAT:
|
||||
case HTP_OP_MUL_MAT_ADD:
|
||||
return op_matmul(octx);
|
||||
|
||||
case HTP_OP_MUL_MAT_ID:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -392,56 +392,49 @@ static inline size_t htp_mm_hvx_get_vtcm_sizes(
|
||||
case HTP_MM_KERNEL_HVX_QUANT_ROW: {
|
||||
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10);
|
||||
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad is also used in dynamic quantizer to store padded src1 rows
|
||||
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
if (vtcm_src0_size < src1_row_size_padded) {
|
||||
vtcm_src0_size = src1_row_size_padded;
|
||||
}
|
||||
|
||||
vtcm_src0_size = vtcm_src0_size * n_threads;
|
||||
vtcm_dst_size = vtcm_dst_size * n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = ne10 / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
vtcm_src0_size = repacked_vtcm_size * n_threads;
|
||||
}
|
||||
|
||||
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
|
||||
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
if (dst_size_per_thread < quant_scratch_size_per_thread) {
|
||||
dst_size_per_thread = quant_scratch_size_per_thread;
|
||||
}
|
||||
vtcm_dst_size = dst_size_per_thread * n_threads;
|
||||
break;
|
||||
}
|
||||
case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: {
|
||||
size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10);
|
||||
|
||||
vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256);
|
||||
|
||||
size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256);
|
||||
if (vtcm_src0_size < src1_row_size_padded) {
|
||||
vtcm_src0_size = src1_row_size_padded;
|
||||
}
|
||||
|
||||
vtcm_src0_size = vtcm_src0_size * n_threads;
|
||||
vtcm_dst_size = vtcm_dst_size * n_threads;
|
||||
|
||||
if (is_repack) {
|
||||
uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
uint32_t n_k_tiles = ne10 / 32;
|
||||
uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
vtcm_src0_size = repacked_vtcm_size * n_threads;
|
||||
}
|
||||
|
||||
size_t quant_scratch_size_per_thread = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float));
|
||||
size_t dst_size_per_thread = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0;
|
||||
if (dst_size_per_thread < quant_scratch_size_per_thread) {
|
||||
dst_size_per_thread = quant_scratch_size_per_thread;
|
||||
}
|
||||
vtcm_dst_size = dst_size_per_thread * n_threads;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -463,7 +456,8 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
|
||||
size_t src0_row_size, // nb01
|
||||
uint32_t n_prefetch,
|
||||
size_t * vtcm_src0_size_out,
|
||||
size_t * vtcm_src1_size_out
|
||||
size_t * vtcm_src1_size_out,
|
||||
size_t * vtcm_dst_size_out
|
||||
) {
|
||||
const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 ||
|
||||
wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL ||
|
||||
@@ -476,29 +470,22 @@ static inline size_t htp_mm_hvx_id_get_vtcm_sizes(
|
||||
size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256);
|
||||
size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256);
|
||||
|
||||
// src0 spad also holds temporary transposed src1 columns during dynamic quantization.
|
||||
const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float));
|
||||
if (src0_sz_per_thread < src1_row_size_padded) {
|
||||
src0_sz_per_thread = src1_row_size_padded;
|
||||
}
|
||||
|
||||
if (is_repack) {
|
||||
const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype);
|
||||
const uint32_t n_k_tiles = ne10 / 32;
|
||||
const uint32_t tile_row_size = n_k_tiles * aligned_tile_size;
|
||||
size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256);
|
||||
if (repacked_vtcm_size < src1_row_size_padded) {
|
||||
repacked_vtcm_size = src1_row_size_padded;
|
||||
}
|
||||
src0_sz_per_thread = repacked_vtcm_size;
|
||||
}
|
||||
|
||||
const size_t vtcm_src0_size = src0_sz_per_thread * n_threads;
|
||||
const size_t vtcm_dst_size = htp_mm_round_up(ne10 * sizeof(float), QK_Q8_0_TILED * sizeof(float)) * n_threads;
|
||||
|
||||
*vtcm_src0_size_out = vtcm_src0_size;
|
||||
*vtcm_src1_size_out = src1_sz;
|
||||
*vtcm_dst_size_out = vtcm_dst_size;
|
||||
|
||||
return vtcm_src0_size + src1_sz;
|
||||
return vtcm_src0_size + src1_sz + vtcm_dst_size;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
@@ -31,6 +31,11 @@ if (GGML_OPENCL_EMBED_KERNELS)
|
||||
target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
|
||||
endif ()
|
||||
|
||||
if (GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
|
||||
message(STATUS "OpenCL will use precompiled binary kernels for Adreno (improved performance on some platforms)")
|
||||
add_compile_definitions(GGML_OPENCL_USE_ADRENO_BIN_KERNELS)
|
||||
endif ()
|
||||
|
||||
function(ggml_opencl_add_kernel KNAME)
|
||||
set(KERN_HDR ${CMAKE_CURRENT_BINARY_DIR}/autogenerated/${KNAME}.cl.h)
|
||||
set(KERN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/kernels/${KNAME}.cl)
|
||||
@@ -78,6 +83,8 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mv_f16_f32_l4
|
||||
mul_mv_f16_f32
|
||||
mul_mv_f32_f32
|
||||
mul_mv_q1_0_f32
|
||||
mul_mv_q1_0_f32_flat
|
||||
mul_mv_q4_0_f32
|
||||
mul_mv_q4_0_f32_v
|
||||
mul_mv_q4_0_f32_8x_flat
|
||||
@@ -128,6 +135,7 @@ set(GGML_OPENCL_KERNELS
|
||||
moe_sort_by_expert
|
||||
mul_mm_f32_f32_l4_lm
|
||||
mul_mm_f16_f32_l4_lm
|
||||
mul_mm_q1_0_f32_l4_lm
|
||||
mul_mm_q4_0_f32_l4_lm
|
||||
mul_mm_q4_1_f32_l4_lm
|
||||
mul_mm_q5_0_f32_l4_lm
|
||||
@@ -137,6 +145,8 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mm_q4_k_f32_l4_lm
|
||||
mul_mm_q5_k_f32_l4_lm
|
||||
mul_mm_q6_k_f32_l4_lm
|
||||
gemv_noshuffle_q1_0_f32
|
||||
gemm_noshuffle_q1_0_f32
|
||||
gemv_noshuffle_q4_0_f32
|
||||
gemv_noshuffle_q4_0_f32_spec
|
||||
gemm_noshuffle_q4_0_f32
|
||||
@@ -192,7 +202,10 @@ set(GGML_OPENCL_KERNELS
|
||||
mul_mm_f16_f32_kq_kqv
|
||||
conv2d
|
||||
conv2d_f16_f32
|
||||
flash_attn_pre_f16
|
||||
flash_attn_f32_f16
|
||||
flash_attn_f32_q8_0
|
||||
flash_attn_f32_q4_0
|
||||
flash_attn_f16
|
||||
flash_attn_f32
|
||||
)
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
#pragma once
|
||||
|
||||
// Flash-attention per-(dk,dv) tile tuning for the Adreno OpenCL backend.
|
||||
// Isolated from ggml-opencl.cpp so the tuning numbers are easy to find and
|
||||
// edit; the FA dispatch and kernel-compile logic stay in the main file.
|
||||
// This header is a file section — it is #included exactly once, at the point
|
||||
// in ggml-opencl.cpp where the ggml logging macros are already in scope.
|
||||
|
||||
// Per-(dk, dv) FA config; shared by dispatch and supports_op.
|
||||
struct ggml_opencl_fa_dim {
|
||||
int dk; int dv; int bm; int bn; int n_split; int nkv_split_threshold;
|
||||
};
|
||||
|
||||
// Split variant fires when n_kv >= threshold (threshold=0 -> always split).
|
||||
// Default tuning covers Adreno 7xx/8xx mobile and X1-series laptop GPUs.
|
||||
static const ggml_opencl_fa_dim g_fa_dims_adreno_default[] = {
|
||||
{ 40, 40, 64, 32, 1, 0}, { 64, 64, 64, 32, 2, 64},
|
||||
{ 80, 80, 64, 32, 2, 64}, { 96, 96, 64, 32, 2, 64},
|
||||
{112, 112, 64, 32, 2, 64}, {128, 128, 64, 32, 2, 64},
|
||||
{192, 128, 16, 16, 1, 0},
|
||||
{192, 192, 16, 16, 1, 0},
|
||||
{256, 256, 16, 16, 16, 0},
|
||||
};
|
||||
|
||||
struct ggml_opencl_fa_dim_table {
|
||||
const ggml_opencl_fa_dim * data;
|
||||
size_t count;
|
||||
|
||||
const ggml_opencl_fa_dim * begin() const { return data; }
|
||||
const ggml_opencl_fa_dim * end() const { return data + count; }
|
||||
};
|
||||
|
||||
// Mutable copy of the active table; GGML_OPENCL_FA_TUNE patches entries here
|
||||
// at backend init without touching the const source table.
|
||||
static ggml_opencl_fa_dim g_fa_dims_runtime[
|
||||
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0])];
|
||||
|
||||
static ggml_opencl_fa_dim_table g_opencl_fa_dims = {
|
||||
g_fa_dims_adreno_default,
|
||||
sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]),
|
||||
};
|
||||
|
||||
// GGML_OPENCL_FA_TUNE=dk:dv:bm:bn:nsplit:thr[,…] — patches matching entries
|
||||
// in the active table at backend init, before the first FA kernel compiles.
|
||||
// Unmatched (dk,dv) pairs are warned and ignored.
|
||||
static void ggml_opencl_fa_apply_env_overrides() {
|
||||
const char * e = std::getenv("GGML_OPENCL_FA_TUNE");
|
||||
if (!e || !e[0]) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string s = e;
|
||||
size_t pos = 0;
|
||||
while (pos < s.size()) {
|
||||
size_t comma = s.find(',', pos);
|
||||
std::string entry = s.substr(pos, comma == std::string::npos ? std::string::npos : comma - pos);
|
||||
int dk, dv, bm, bn, nsplit, thr;
|
||||
if (std::sscanf(entry.c_str(), "%d:%d:%d:%d:%d:%d", &dk, &dv, &bm, &bn, &nsplit, &thr) == 6) {
|
||||
bool patched = false;
|
||||
for (size_t i = 0; i < g_opencl_fa_dims.count; ++i) {
|
||||
ggml_opencl_fa_dim & d = g_fa_dims_runtime[i];
|
||||
if (d.dk == dk && d.dv == dv) {
|
||||
d.bm = bm; d.bn = bn; d.n_split = nsplit; d.nkv_split_threshold = thr;
|
||||
GGML_LOG_INFO("ggml_opencl: FA tune override DK=%d DV=%d -> bm=%d bn=%d n_split=%d thr=%d\n",
|
||||
dk, dv, bm, bn, nsplit, thr);
|
||||
patched = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!patched) {
|
||||
GGML_LOG_WARN("ggml_opencl: FA tune override DK=%d DV=%d ignored (no matching dim)\n", dk, dv);
|
||||
}
|
||||
} else {
|
||||
GGML_LOG_WARN("ggml_opencl: FA tune override entry malformed: '%s'\n", entry.c_str());
|
||||
}
|
||||
if (comma == std::string::npos) break;
|
||||
pos = comma + 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Copy the default table into the mutable runtime buffer and apply any
|
||||
// GGML_OPENCL_FA_TUNE overrides. A per-generation table can be added here
|
||||
// once it has been tuned on hardware.
|
||||
static void ggml_cl_init_fa_dims_table() {
|
||||
const size_t count = sizeof(g_fa_dims_adreno_default) / sizeof(g_fa_dims_adreno_default[0]);
|
||||
for (size_t i = 0; i < count; ++i) {
|
||||
g_fa_dims_runtime[i] = g_fa_dims_adreno_default[i];
|
||||
}
|
||||
g_opencl_fa_dims = { g_fa_dims_runtime, count };
|
||||
ggml_opencl_fa_apply_env_overrides();
|
||||
}
|
||||
+2709
-254
File diff suppressed because it is too large
Load Diff
@@ -27,6 +27,8 @@
|
||||
#define QR5_1 2
|
||||
#define QK8_0 32
|
||||
#define QR8_0 1
|
||||
#define QK1_0 128
|
||||
#define QR1_0 1
|
||||
#define QK_K 256
|
||||
#define K_SCALE_SIZE (3 * QK_K / 64)
|
||||
#define K_QUANTS_PER_ITERATION 2
|
||||
@@ -38,6 +40,14 @@ typedef ushort uint16_t;
|
||||
typedef int int32_t;
|
||||
typedef uint uint32_t;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q1_0
|
||||
//------------------------------------------------------------------------------
|
||||
typedef struct {
|
||||
half d; // delta
|
||||
uchar qs[QK1_0/8]; // 1-bit signs (16 bytes)
|
||||
} block_q1_0;
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// block_q4_0
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -159,6 +169,42 @@ kernel void kernel_convert_f16_to_bf16(
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_convert_block_q1_0
|
||||
// Convert block_q1_0 (AOS) to 2 separate arrays (SOA): quant bytes + scales.
|
||||
// q1_0 bits are stored in natural order (bit j of byte i -> weight 8*i + j)
|
||||
//------------------------------------------------------------------------------
|
||||
kernel void kernel_convert_block_q1_0(
|
||||
global block_q1_0 * src0,
|
||||
global uchar * dst_q,
|
||||
global half * dst_d
|
||||
) {
|
||||
global block_q1_0 * b = (global block_q1_0 *) src0 + get_global_id(0);
|
||||
global uchar * q = (global uchar *) dst_q + (QK1_0/8)*get_global_id(0);
|
||||
global half * d = (global half *) dst_d + get_global_id(0);
|
||||
|
||||
*d = b->d;
|
||||
|
||||
for (int i = 0; i < QK1_0/8; ++i) {
|
||||
q[i] = b->qs[i];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_q1_0(
|
||||
global uchar * src_q,
|
||||
global half * src_d,
|
||||
global block_q1_0 * dst
|
||||
) {
|
||||
global block_q1_0 * b = (global block_q1_0 *) dst + get_global_id(0);
|
||||
global uchar * q = (global uchar *) src_q + (QK1_0/8)*get_global_id(0);
|
||||
global half * d = (global half *) src_d + get_global_id(0);
|
||||
|
||||
b->d = *d;
|
||||
for (int i = 0; i < QK1_0/8; ++i) {
|
||||
b->qs[i] = q[i];
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// kernel_convert_block_q4_0
|
||||
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
|
||||
@@ -1582,6 +1628,158 @@ kernel void kernel_restore_block_q8_0(
|
||||
}
|
||||
}
|
||||
|
||||
// View-aware AoS q8_0 -> f32 dequant (f32/f32 FA path).
|
||||
kernel void kernel_dequant_q8_0_f32_view_aos(
|
||||
global char * src,
|
||||
ulong src_offset,
|
||||
ulong src_nb1,
|
||||
ulong src_nb2,
|
||||
ulong src_nb3,
|
||||
int nblk0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
global float * dst
|
||||
) {
|
||||
int blk_i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int batch = get_global_id(2);
|
||||
|
||||
if (blk_i0 >= nblk0) return;
|
||||
if (i1 >= ne1) return;
|
||||
|
||||
int i2 = batch % ne2;
|
||||
int i3 = batch / ne2;
|
||||
if (i3 >= ne3) return;
|
||||
|
||||
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
|
||||
float d = vload_half(0, (global half *)block);
|
||||
global char * qs = block + 2;
|
||||
|
||||
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
|
||||
global float * out = dst + (dst_row_base + blk_i0) * QK8_0;
|
||||
|
||||
for (int i = 0; i < QK8_0; ++i) {
|
||||
out[i] = d * (float)qs[i];
|
||||
}
|
||||
}
|
||||
|
||||
// View-aware AoS q8_0 -> f16 dequant. Rows tight, batch strides may be gapped.
|
||||
kernel void kernel_dequant_q8_0_f16_view_aos(
|
||||
global char * src,
|
||||
ulong src_offset,
|
||||
ulong src_nb1,
|
||||
ulong src_nb2,
|
||||
ulong src_nb3,
|
||||
int nblk0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
global half * dst
|
||||
) {
|
||||
int blk_i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int batch = get_global_id(2);
|
||||
|
||||
if (blk_i0 >= nblk0) return;
|
||||
if (i1 >= ne1) return;
|
||||
|
||||
int i2 = batch % ne2;
|
||||
int i3 = batch / ne2;
|
||||
if (i3 >= ne3) return;
|
||||
|
||||
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK8_0);
|
||||
float d = vload_half(0, (global half *)block);
|
||||
global char * qs = block + 2;
|
||||
|
||||
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
|
||||
global half * out = dst + (dst_row_base + blk_i0) * QK8_0;
|
||||
|
||||
for (int i = 0; i < QK8_0; ++i) {
|
||||
out[i] = (half)(d * (float)qs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// View-aware AoS q4_0 -> f32 dequant (mirrors the q8_0 view variant).
|
||||
kernel void kernel_dequant_q4_0_f32_view_aos(
|
||||
global char * src,
|
||||
ulong src_offset,
|
||||
ulong src_nb1,
|
||||
ulong src_nb2,
|
||||
ulong src_nb3,
|
||||
int nblk0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
global float * dst
|
||||
) {
|
||||
int blk_i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int batch = get_global_id(2);
|
||||
|
||||
if (blk_i0 >= nblk0) return;
|
||||
if (i1 >= ne1) return;
|
||||
|
||||
int i2 = batch % ne2;
|
||||
int i3 = batch / ne2;
|
||||
if (i3 >= ne3) return;
|
||||
|
||||
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
|
||||
float d = vload_half(0, (global half *)block);
|
||||
global uchar * qs = (global uchar *)(block + 2);
|
||||
|
||||
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
|
||||
global float * out = dst + (dst_row_base + blk_i0) * QK4_0;
|
||||
|
||||
for (int i = 0; i < QK4_0/2; ++i) {
|
||||
uchar byte = qs[i];
|
||||
int q0 = (int)(byte & 0x0F) - 8;
|
||||
int q1 = (int)(byte >> 4) - 8;
|
||||
out[i] = d * (float)q0;
|
||||
out[i + QK4_0/2] = d * (float)q1;
|
||||
}
|
||||
}
|
||||
|
||||
// View-aware AoS q4_0 -> f16 dequant (mirrors the q8_0 view variant).
|
||||
kernel void kernel_dequant_q4_0_f16_view_aos(
|
||||
global char * src,
|
||||
ulong src_offset,
|
||||
ulong src_nb1,
|
||||
ulong src_nb2,
|
||||
ulong src_nb3,
|
||||
int nblk0,
|
||||
int ne1,
|
||||
int ne2,
|
||||
int ne3,
|
||||
global half * dst
|
||||
) {
|
||||
int blk_i0 = get_global_id(0);
|
||||
int i1 = get_global_id(1);
|
||||
int batch = get_global_id(2);
|
||||
|
||||
if (blk_i0 >= nblk0) return;
|
||||
if (i1 >= ne1) return;
|
||||
|
||||
int i2 = batch % ne2;
|
||||
int i3 = batch / ne2;
|
||||
if (i3 >= ne3) return;
|
||||
|
||||
global char * block = src + src_offset + (ulong)i3*src_nb3 + (ulong)i2*src_nb2 + (ulong)i1*src_nb1 + (ulong)blk_i0 * (2 + QK4_0/2);
|
||||
float d = vload_half(0, (global half *)block);
|
||||
global uchar * qs = (global uchar *)(block + 2);
|
||||
|
||||
ulong dst_row_base = ((ulong)i3 * ne2 * ne1 + (ulong)i2 * ne1 + (ulong)i1) * nblk0;
|
||||
global half * out = dst + (dst_row_base + blk_i0) * QK4_0;
|
||||
|
||||
for (int i = 0; i < QK4_0/2; ++i) {
|
||||
uchar byte = qs[i];
|
||||
int q0 = (int)(byte & 0x0F) - 8;
|
||||
int q1 = (int)(byte >> 4) - 8;
|
||||
out[i] = (half)(d * (float)q0);
|
||||
out[i + QK4_0/2] = (half)(d * (float)q1);
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_restore_block_q8_0_trans(
|
||||
global uchar * src_q,
|
||||
global half * src_d,
|
||||
|
||||
@@ -4,14 +4,26 @@
|
||||
#define ACC_TYPE4 float4
|
||||
#define DATA_TYPE half
|
||||
#define DATA_TYPE4 half4
|
||||
#define CONVERT_ACC4(x) convert_float4(x)
|
||||
#define CONVERT_DATA4(x) convert_half4(x)
|
||||
#define CONVERT_ACC4(x) ((float4)((float)(x).s0, (float)(x).s1, (float)(x).s2, (float)(x).s3))
|
||||
#define CONVERT_DATA4(x) ((half4)((half)(x).s0, (half)(x).s1, (half)(x).s2, (half)(x).s3))
|
||||
|
||||
#define DK_VEC (DK/4)
|
||||
#define DV_VEC (DV/4)
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#define Q1_WG_SIZE 64
|
||||
|
||||
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
|
||||
// infinite operand can cause undefined behavior and miscompilation for exp.
|
||||
// Therefore, a large negative value is used instead.
|
||||
#define FA_M_INIT (-3.0e38f)
|
||||
|
||||
// Drop full unroll at DK>=192 — Adreno compiler host-memory budget.
|
||||
#if DK >= 192
|
||||
#define FA_UNROLL
|
||||
#else
|
||||
#define FA_UNROLL _Pragma("unroll")
|
||||
#endif
|
||||
|
||||
inline float get_alibi_slope(
|
||||
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
||||
) {
|
||||
@@ -81,18 +93,18 @@ __kernel void flash_attn_f16(
|
||||
if (my_query_row < n_q) {
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
ACC_TYPE m_i = FA_M_INIT;
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
@@ -125,49 +137,72 @@ __kernel void flash_attn_f16(
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
for (int j = 0; j < BLOCK_N; j += 4) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
const int k_row2 = k_start + j + 2;
|
||||
const int k_row3 = k_start + j + 3;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
|
||||
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
|
||||
}
|
||||
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
|
||||
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
||||
const int causal_limit = n_kv - n_q + my_query_row;
|
||||
if (k_row0 > causal_limit) s0 = FA_M_INIT;
|
||||
if (k_row1 > causal_limit) s1 = FA_M_INIT;
|
||||
if (k_row2 > causal_limit) s2 = FA_M_INIT;
|
||||
if (k_row3 > causal_limit) s3 = FA_M_INIT;
|
||||
}
|
||||
|
||||
if (k_row0 >= n_kv) score0 = -INFINITY;
|
||||
if (k_row1 >= n_kv) score1 = -INFINITY;
|
||||
if (k_row0 >= n_kv) s0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) s1 = FA_M_INIT;
|
||||
if (k_row2 >= n_kv) s2 = FA_M_INIT;
|
||||
if (k_row3 >= n_kv) s3 = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL) {
|
||||
const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
|
||||
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
s0 = logit_softcap * tanh(s0 / logit_softcap);
|
||||
s1 = logit_softcap * tanh(s1 / logit_softcap);
|
||||
s2 = logit_softcap * tanh(s2 / logit_softcap);
|
||||
s3 = logit_softcap * tanh(s3 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
const ACC_TYPE p0 = exp(score0 - m_new);
|
||||
const ACC_TYPE p1 = exp(score1 - m_new);
|
||||
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
||||
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
|
||||
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
|
||||
const ACC_TYPE p0 = native_exp(s0 - m_new);
|
||||
const ACC_TYPE p1 = native_exp(s1 - m_new);
|
||||
const ACC_TYPE p2 = native_exp(s2 - m_new);
|
||||
const ACC_TYPE p3 = native_exp(s3 - m_new);
|
||||
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
|
||||
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
|
||||
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
|
||||
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
|
||||
mad(p0, CONVERT_ACC4(l_v[j][i]),
|
||||
o_acc[i] * scale_prev))));
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1;
|
||||
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
@@ -179,7 +214,7 @@ __kernel void flash_attn_f16(
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
@@ -191,12 +226,12 @@ __kernel void flash_attn_f16(
|
||||
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_i;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
@@ -258,7 +293,7 @@ __kernel void flash_attn_f16_q1(
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
@@ -270,12 +305,12 @@ __kernel void flash_attn_f16_q1(
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -293,7 +328,7 @@ __kernel void flash_attn_f16_q1(
|
||||
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
||||
local_m[tid] = m_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -301,7 +336,7 @@ __kernel void flash_attn_f16_q1(
|
||||
const ACC_TYPE m_final = local_m[0];
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
@@ -311,7 +346,7 @@ __kernel void flash_attn_f16_q1(
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -325,7 +360,7 @@ __kernel void flash_attn_f16_q1(
|
||||
}
|
||||
const ACC_TYPE p = exp(score - m_final);
|
||||
l_i += p;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
|
||||
}
|
||||
@@ -335,7 +370,7 @@ __kernel void flash_attn_f16_q1(
|
||||
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
||||
local_l[tid] = l_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_l[tid] += local_l[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -354,7 +389,7 @@ __kernel void flash_attn_f16_q1(
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
local_o_comp[tid] = o_acc[i];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -364,7 +399,7 @@ __kernel void flash_attn_f16_q1(
|
||||
}
|
||||
}
|
||||
} else if (tid == 0) {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,18 @@
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#define Q1_WG_SIZE 64
|
||||
|
||||
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
|
||||
// infinite operand can cause undefined behavior and miscompilation for exp.
|
||||
// Therefore, a large negative value is used instead.
|
||||
#define FA_M_INIT (-3.0e38f)
|
||||
|
||||
// Drop full unroll at DK>=192 — Adreno compiler host-memory budget.
|
||||
#if DK >= 192
|
||||
#define FA_UNROLL
|
||||
#else
|
||||
#define FA_UNROLL _Pragma("unroll")
|
||||
#endif
|
||||
|
||||
inline float get_alibi_slope(
|
||||
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
||||
) {
|
||||
@@ -82,18 +94,18 @@ __kernel void flash_attn_f32(
|
||||
if (my_query_row < n_q) {
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
ACC_TYPE m_i = FA_M_INIT;
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
@@ -126,49 +138,72 @@ __kernel void flash_attn_f32(
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
for (int j = 0; j < BLOCK_N; j += 4) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
const int k_row2 = k_start + j + 2;
|
||||
const int k_row3 = k_start + j + 3;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
dot_acc0 = mad(qk, CONVERT_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(qk, CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
dot_acc2 = mad(qk, CONVERT_ACC4(l_k[j+2][k]), dot_acc2);
|
||||
dot_acc3 = mad(qk, CONVERT_ACC4(l_k[j+3][k]), dot_acc3);
|
||||
}
|
||||
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
|
||||
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
||||
const int causal_limit = n_kv - n_q + my_query_row;
|
||||
if (k_row0 > causal_limit) s0 = FA_M_INIT;
|
||||
if (k_row1 > causal_limit) s1 = FA_M_INIT;
|
||||
if (k_row2 > causal_limit) s2 = FA_M_INIT;
|
||||
if (k_row3 > causal_limit) s3 = FA_M_INIT;
|
||||
}
|
||||
|
||||
if (k_row0 >= n_kv) score0 = -INFINITY;
|
||||
if (k_row1 >= n_kv) score1 = -INFINITY;
|
||||
if (k_row0 >= n_kv) s0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) s1 = FA_M_INIT;
|
||||
if (k_row2 >= n_kv) s2 = FA_M_INIT;
|
||||
if (k_row3 >= n_kv) s3 = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
|
||||
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
s0 = logit_softcap * tanh(s0 / logit_softcap);
|
||||
s1 = logit_softcap * tanh(s1 / logit_softcap);
|
||||
s2 = logit_softcap * tanh(s2 / logit_softcap);
|
||||
s3 = logit_softcap * tanh(s3 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
const ACC_TYPE p0 = exp(score0 - m_new);
|
||||
const ACC_TYPE p1 = exp(score1 - m_new);
|
||||
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
||||
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
|
||||
const ACC_TYPE scale_prev = native_exp(m_i - m_new);
|
||||
const ACC_TYPE p0 = native_exp(s0 - m_new);
|
||||
const ACC_TYPE p1 = native_exp(s1 - m_new);
|
||||
const ACC_TYPE p2 = native_exp(s2 - m_new);
|
||||
const ACC_TYPE p3 = native_exp(s3 - m_new);
|
||||
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
|
||||
o_acc[i] = mad(p3, CONVERT_ACC4(l_v[j+3][i]),
|
||||
mad(p2, CONVERT_ACC4(l_v[j+2][i]),
|
||||
mad(p1, CONVERT_ACC4(l_v[j+1][i]),
|
||||
mad(p0, CONVERT_ACC4(l_v[j][i]),
|
||||
o_acc[i] * scale_prev))));
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1;
|
||||
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
@@ -180,7 +215,7 @@ __kernel void flash_attn_f32(
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
@@ -192,12 +227,12 @@ __kernel void flash_attn_f32(
|
||||
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_i;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
@@ -259,7 +294,7 @@ __kernel void flash_attn_f32_q1(
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
||||
const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_ACC4(q_ptr[i]);
|
||||
}
|
||||
@@ -271,12 +306,12 @@ __kernel void flash_attn_f32_q1(
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -294,7 +329,7 @@ __kernel void flash_attn_f32_q1(
|
||||
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
||||
local_m[tid] = m_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -302,7 +337,7 @@ __kernel void flash_attn_f32_q1(
|
||||
const ACC_TYPE m_final = local_m[0];
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
@@ -312,7 +347,7 @@ __kernel void flash_attn_f32_q1(
|
||||
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
|
||||
const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
@@ -326,7 +361,7 @@ __kernel void flash_attn_f32_q1(
|
||||
}
|
||||
const ACC_TYPE p = exp(score - m_final);
|
||||
l_i += p;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
|
||||
}
|
||||
@@ -336,7 +371,7 @@ __kernel void flash_attn_f32_q1(
|
||||
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
||||
local_l[tid] = l_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_l[tid] += local_l[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -355,7 +390,7 @@ __kernel void flash_attn_f32_q1(
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
local_o_comp[tid] = o_acc[i];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -365,7 +400,7 @@ __kernel void flash_attn_f32_q1(
|
||||
}
|
||||
}
|
||||
} else if (tid == 0) {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_khr_subgroup_shuffle
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroup_shuffle : enable
|
||||
#define HAS_SUBGROUP_SHUFFLE 1
|
||||
#elif defined(cl_qcom_subgroup_shuffle)
|
||||
#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
|
||||
#define HAS_SUBGROUP_SHUFFLE 1
|
||||
#endif
|
||||
|
||||
#define ACC_TYPE float
|
||||
#define ACC_TYPE4 float4
|
||||
#define Q_DATA_TYPE4 float4
|
||||
@@ -12,9 +20,34 @@
|
||||
|
||||
#define DK_VEC (DK/4)
|
||||
#define DV_VEC (DV/4)
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#define Q1_WG_SIZE 64
|
||||
|
||||
// The kernels are built with -cl-finite-math-only. On some older Adreno GPUs,
|
||||
// infinite operand can cause undefined behavior and miscompilation for exp.
|
||||
// Therefore, a large negative value is used instead.
|
||||
#define FA_M_INIT (-3.0e38f)
|
||||
|
||||
// Drop full unroll at DK>=192 — Adreno compiler host-memory budget.
|
||||
#if DK >= 192
|
||||
#define FA_UNROLL
|
||||
#else
|
||||
#define FA_UNROLL _Pragma("unroll")
|
||||
#endif
|
||||
|
||||
// N_SPLIT>1 splits DK/DV across threads to cut per-thread register use.
|
||||
#ifndef N_SPLIT
|
||||
#define N_SPLIT 1
|
||||
#endif
|
||||
|
||||
#define SPLIT_DK_VEC (DK_VEC / N_SPLIT)
|
||||
#define SPLIT_DV_VEC (DV_VEC / N_SPLIT)
|
||||
|
||||
#if N_SPLIT > 1
|
||||
#define WG_SIZE (BLOCK_M * N_SPLIT)
|
||||
#else
|
||||
#define WG_SIZE (BLOCK_M)
|
||||
#endif
|
||||
|
||||
inline float get_alibi_slope(
|
||||
const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
|
||||
) {
|
||||
@@ -54,19 +87,38 @@ __kernel void flash_attn_f32_f16(
|
||||
const int mask_ne2,
|
||||
const int mask_ne3,
|
||||
const global void* sinks_void,
|
||||
const ulong sinks_offset
|
||||
const ulong sinks_offset,
|
||||
const global void * k_pad_void,
|
||||
const global void * v_pad_void,
|
||||
const global void * mask_pad_void,
|
||||
const global char * blk,
|
||||
const int n_kv_blocks,
|
||||
const ulong mask_pad_nb1,
|
||||
const ulong mask_pad_nb2,
|
||||
const ulong mask_pad_nb3
|
||||
) {
|
||||
const int tid = get_local_id(0);
|
||||
const int block_q_idx = get_group_id(0);
|
||||
const int head_batch_idx = get_global_id(1);
|
||||
|
||||
const int my_query_row = block_q_idx * BLOCK_M + tid;
|
||||
#if N_SPLIT > 1
|
||||
const int q_lane = tid / N_SPLIT;
|
||||
const int split_idx = tid % N_SPLIT;
|
||||
#else
|
||||
const int q_lane = tid;
|
||||
const int split_idx = 0;
|
||||
#endif
|
||||
|
||||
const int my_query_row = block_q_idx * BLOCK_M + q_lane;
|
||||
const int query_valid = my_query_row < n_q;
|
||||
|
||||
const int batch_idx = head_batch_idx / n_head;
|
||||
const int head_idx = head_batch_idx % n_head;
|
||||
|
||||
const int gqa_ratio = n_head / n_head_kv;
|
||||
const int head_kv_idx = head_idx / gqa_ratio;
|
||||
const int mask_head_idx = mask_void != NULL ? head_idx % mask_ne2 : 0;
|
||||
const int mask_batch_idx = mask_void != NULL ? batch_idx % mask_ne3 : 0;
|
||||
|
||||
const global char* q_base = (const global char*)q_void + q_offset;
|
||||
const global char* k_base = (const global char*)k_void + k_offset;
|
||||
@@ -75,27 +127,41 @@ __kernel void flash_attn_f32_f16(
|
||||
|
||||
const global char* mask_base = NULL;
|
||||
if (mask_void != NULL) {
|
||||
const int mask_head_idx = head_idx % mask_ne2;
|
||||
const int mask_batch_idx = batch_idx % mask_ne3;
|
||||
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
||||
}
|
||||
const global char* mask_pad_base = NULL;
|
||||
if (mask_pad_void != NULL) {
|
||||
mask_pad_base = (const global char*)mask_pad_void + mask_batch_idx * mask_pad_nb3 + mask_head_idx * mask_pad_nb2;
|
||||
}
|
||||
const global char* blk_base = NULL;
|
||||
if (blk != NULL) {
|
||||
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
|
||||
blk_base = blk + (((mask_batch_idx * mask_ne2) + mask_head_idx) * n_q_blocks + block_q_idx) * n_kv_blocks;
|
||||
}
|
||||
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
if (my_query_row < n_q) {
|
||||
ACC_TYPE4 q_priv[SPLIT_DK_VEC];
|
||||
const int dk_off = split_idx * SPLIT_DK_VEC;
|
||||
if (query_valid) {
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
|
||||
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_Q_ACC4(q_ptr[dk_off + i]);
|
||||
}
|
||||
} else {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DK_VEC; ++i) {
|
||||
q_priv[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
ACC_TYPE4 o_acc[SPLIT_DV_VEC];
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
}
|
||||
ACC_TYPE m_i = -INFINITY;
|
||||
|
||||
ACC_TYPE m_i = FA_M_INIT;
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
@@ -103,86 +169,369 @@ __kernel void flash_attn_f32_f16(
|
||||
__local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
|
||||
__local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
|
||||
|
||||
#if N_SPLIT > 1 && !defined(HAS_SUBGROUP_SHUFFLE)
|
||||
__local ACC_TYPE local_partial[BLOCK_N][WG_SIZE];
|
||||
__local ACC_TYPE local_p[BLOCK_M][BLOCK_N];
|
||||
__local ACC_TYPE local_softmax_scale[BLOCK_M];
|
||||
__local ACC_TYPE local_l_inv[BLOCK_M];
|
||||
#endif
|
||||
|
||||
for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
|
||||
char blk_cur = 1;
|
||||
if (blk_base != NULL) {
|
||||
blk_cur = blk_base[k_start / BLOCK_N];
|
||||
if (blk_cur == 0) continue;
|
||||
}
|
||||
|
||||
const int use_kv_pad = k_pad_void != NULL && k_start + BLOCK_N > n_kv;
|
||||
const int k_tile_start = use_kv_pad ? 0 : k_start;
|
||||
const ulong k_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * k_nb1 : k_nb2;
|
||||
const ulong k_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * k_tile_nb2 : k_nb3;
|
||||
const ulong v_tile_nb2 = use_kv_pad ? (ulong) BLOCK_N * v_nb1 : v_nb2;
|
||||
const ulong v_tile_nb3 = use_kv_pad ? (ulong) n_head_kv * v_tile_nb2 : v_nb3;
|
||||
const global char* k_tile_base = use_kv_pad ? (const global char*) k_pad_void : k_base;
|
||||
const global char* v_tile_base = use_kv_pad ? (const global char*) v_pad_void : v_base;
|
||||
|
||||
for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
|
||||
const int row = i / DK_VEC;
|
||||
const int col = i % DK_VEC;
|
||||
const int k_row_idx = k_start + row;
|
||||
if (k_row_idx < n_kv) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
|
||||
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
|
||||
const int k_row_idx = k_tile_start + row;
|
||||
if (use_kv_pad || k_row_idx < n_kv) {
|
||||
const ulong k_row_offset = batch_idx * k_tile_nb3 + head_kv_idx * k_tile_nb2 + k_row_idx * k_nb1;
|
||||
l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_tile_base + k_row_offset))[col];
|
||||
} else {
|
||||
l_k[row][col] = (KV_DATA_TYPE4)(0.0h);
|
||||
}
|
||||
}
|
||||
for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
|
||||
const int row = i / DV_VEC;
|
||||
const int col = i % DV_VEC;
|
||||
const int v_row_idx = k_start + row;
|
||||
if (v_row_idx < n_kv) {
|
||||
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
|
||||
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
|
||||
const int v_row_idx = k_tile_start + row;
|
||||
if (use_kv_pad || v_row_idx < n_kv) {
|
||||
const ulong v_row_offset = batch_idx * v_tile_nb3 + head_kv_idx * v_tile_nb2 + v_row_idx * v_nb1;
|
||||
l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_tile_base + v_row_offset))[col];
|
||||
} else {
|
||||
l_v[row][col] = (KV_DATA_TYPE4)(0.0h);
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (my_query_row >= n_q) {
|
||||
continue;
|
||||
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
|
||||
{
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
|
||||
ACC_TYPE partial0 = 0.0f;
|
||||
ACC_TYPE partial1 = 0.0f;
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < SPLIT_DK_VEC; k++) {
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
ACC_TYPE4 dot0 = qk * CONVERT_KV_ACC4(l_k[j ][dk_off + k]);
|
||||
ACC_TYPE4 dot1 = qk * CONVERT_KV_ACC4(l_k[j+1][dk_off + k]);
|
||||
partial0 += dot0.s0 + dot0.s1 + dot0.s2 + dot0.s3;
|
||||
partial1 += dot1.s0 + dot1.s1 + dot1.s2 + dot1.s3;
|
||||
}
|
||||
|
||||
FA_UNROLL
|
||||
for (int step = 1; step < N_SPLIT; step <<= 1) {
|
||||
partial0 += sub_group_shuffle_xor(partial0, step);
|
||||
partial1 += sub_group_shuffle_xor(partial1, step);
|
||||
}
|
||||
|
||||
ACC_TYPE score0 = partial0 * scale;
|
||||
ACC_TYPE score1 = partial1 * scale;
|
||||
|
||||
if (!query_valid) { score0 = FA_M_INIT; score1 = FA_M_INIT; }
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = FA_M_INIT;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = FA_M_INIT;
|
||||
}
|
||||
if (k_row0 >= n_kv) score0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) score1 = FA_M_INIT;
|
||||
|
||||
if (query_valid && mask_base != NULL && blk_cur != 2) {
|
||||
if (use_kv_pad && mask_pad_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
|
||||
score0 += slope * (ACC_TYPE)mask_ptr[j];
|
||||
score1 += slope * (ACC_TYPE)mask_ptr[j + 1];
|
||||
} else {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
|
||||
// far negative so the tile contributes 0, not exp(0)=1.
|
||||
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
|
||||
const ACC_TYPE sp = native_exp(m_i - m_exp);
|
||||
const ACC_TYPE p0 = native_exp(score0 - m_exp);
|
||||
const ACC_TYPE p1 = native_exp(score1 - m_exp);
|
||||
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * sp
|
||||
+ p0 * CONVERT_KV_ACC4(l_v[j ][dv_off + i])
|
||||
+ p1 * CONVERT_KV_ACC4(l_v[j+1][dv_off + i]);
|
||||
}
|
||||
l_i = l_i * sp + p0 + p1;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
|
||||
for (int j = 0; j < BLOCK_N; j += 2) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
#elif N_SPLIT > 1
|
||||
// N_SPLIT>1 fallback (no shuffle): 3-phase local-memory reduction.
|
||||
// Phase 1 — partial dots for all BLOCK_N tokens.
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < SPLIT_DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][dk_off + k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
|
||||
if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
|
||||
}
|
||||
|
||||
if (k_row0 >= n_kv) score0 = -INFINITY;
|
||||
if (k_row1 >= n_kv) score1 = -INFINITY;
|
||||
|
||||
if (mask_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score0 = logit_softcap * tanh(score0 / logit_softcap);
|
||||
score1 = logit_softcap * tanh(score1 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(score0, score1));
|
||||
const ACC_TYPE p0 = exp(score0 - m_new);
|
||||
const ACC_TYPE p1 = exp(score1 - m_new);
|
||||
const ACC_TYPE scale_prev = exp(m_i - m_new);
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1;
|
||||
m_i = m_new;
|
||||
local_partial[j][tid] =
|
||||
dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3;
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE); // 1 barrier: partial dots visible
|
||||
|
||||
// Phase 2 — split_idx==0 reduces partial sums and computes block softmax.
|
||||
if (split_idx == 0) {
|
||||
if (query_valid) {
|
||||
ACC_TYPE m_new = m_i;
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
const int k_row = k_start + j;
|
||||
ACC_TYPE score = 0.0f;
|
||||
FA_UNROLL
|
||||
for (int s = 0; s < N_SPLIT; s++) {
|
||||
score += local_partial[j][q_lane * N_SPLIT + s];
|
||||
}
|
||||
score *= scale;
|
||||
|
||||
if (is_causal && k_row > (n_kv - n_q + my_query_row)) score = FA_M_INIT;
|
||||
if (k_row >= n_kv) score = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL && blk_cur != 2) {
|
||||
if (use_kv_pad && mask_pad_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
|
||||
score += slope * (ACC_TYPE)mask_ptr[j];
|
||||
} else {
|
||||
const global MASK_DATA_TYPE* mask_ptr =
|
||||
(const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row < n_kv) score += slope * (ACC_TYPE)mask_ptr[k_row];
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
score = logit_softcap * tanh(score / logit_softcap);
|
||||
}
|
||||
|
||||
m_new = max(m_new, score);
|
||||
local_p[q_lane][j] = score;
|
||||
}
|
||||
|
||||
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
|
||||
const ACC_TYPE sp = native_exp(m_i - m_exp);
|
||||
ACC_TYPE l_new = l_i * sp;
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
const ACC_TYPE p = native_exp(local_p[q_lane][j] - m_exp);
|
||||
local_p[q_lane][j] = p;
|
||||
l_new += p;
|
||||
}
|
||||
local_softmax_scale[q_lane] = sp;
|
||||
l_i = l_new;
|
||||
m_i = m_new;
|
||||
} else {
|
||||
local_softmax_scale[q_lane] = 1.0f;
|
||||
for (int j = 0; j < BLOCK_N; ++j) local_p[q_lane][j] = 0.0f;
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
// Phase 3 — V accumulate using broadcast probabilities.
|
||||
{
|
||||
const ACC_TYPE sp_block = local_softmax_scale[q_lane];
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] *= sp_block;
|
||||
}
|
||||
for (int j = 0; j < BLOCK_N; ++j) {
|
||||
const ACC_TYPE p = local_p[q_lane][j];
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_acc[i] = mad(p, CONVERT_KV_ACC4(l_v[j][dv_off + i]), o_acc[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
// N_SPLIT==1: j+=4 unroll. Requires BLOCK_N % 4 == 0.
|
||||
if (query_valid) {
|
||||
for (int j = 0; j < BLOCK_N; j += 4) {
|
||||
const int k_row0 = k_start + j;
|
||||
const int k_row1 = k_start + j + 1;
|
||||
const int k_row2 = k_start + j + 2;
|
||||
const int k_row3 = k_start + j + 3;
|
||||
|
||||
ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc2 = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE4 dot_acc3 = (ACC_TYPE4)(0.0f);
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
const ACC_TYPE4 qk = q_priv[k];
|
||||
dot_acc0 = mad(qk, CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
|
||||
dot_acc1 = mad(qk, CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
|
||||
dot_acc2 = mad(qk, CONVERT_KV_ACC4(l_k[j+2][k]), dot_acc2);
|
||||
dot_acc3 = mad(qk, CONVERT_KV_ACC4(l_k[j+3][k]), dot_acc3);
|
||||
}
|
||||
ACC_TYPE s0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
|
||||
ACC_TYPE s1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
|
||||
ACC_TYPE s2 = (dot_acc2.s0 + dot_acc2.s1 + dot_acc2.s2 + dot_acc2.s3) * scale;
|
||||
ACC_TYPE s3 = (dot_acc3.s0 + dot_acc3.s1 + dot_acc3.s2 + dot_acc3.s3) * scale;
|
||||
|
||||
if (is_causal) {
|
||||
const int causal_limit = n_kv - n_q + my_query_row;
|
||||
if (k_row0 > causal_limit) s0 = FA_M_INIT;
|
||||
if (k_row1 > causal_limit) s1 = FA_M_INIT;
|
||||
if (k_row2 > causal_limit) s2 = FA_M_INIT;
|
||||
if (k_row3 > causal_limit) s3 = FA_M_INIT;
|
||||
}
|
||||
if (k_row0 >= n_kv) s0 = FA_M_INIT;
|
||||
if (k_row1 >= n_kv) s1 = FA_M_INIT;
|
||||
if (k_row2 >= n_kv) s2 = FA_M_INIT;
|
||||
if (k_row3 >= n_kv) s3 = FA_M_INIT;
|
||||
|
||||
if (mask_base != NULL && blk_cur != 2) {
|
||||
if (use_kv_pad && mask_pad_base != NULL) {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_pad_base + my_query_row * mask_pad_nb1);
|
||||
s0 += slope * (ACC_TYPE)mask_ptr[j];
|
||||
s1 += slope * (ACC_TYPE)mask_ptr[j + 1];
|
||||
s2 += slope * (ACC_TYPE)mask_ptr[j + 2];
|
||||
s3 += slope * (ACC_TYPE)mask_ptr[j + 3];
|
||||
} else {
|
||||
const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
|
||||
if (k_row0 < n_kv) s0 += slope * (ACC_TYPE)mask_ptr[k_row0];
|
||||
if (k_row1 < n_kv) s1 += slope * (ACC_TYPE)mask_ptr[k_row1];
|
||||
if (k_row2 < n_kv) s2 += slope * (ACC_TYPE)mask_ptr[k_row2];
|
||||
if (k_row3 < n_kv) s3 += slope * (ACC_TYPE)mask_ptr[k_row3];
|
||||
}
|
||||
}
|
||||
|
||||
if (logit_softcap > 0.0f) {
|
||||
s0 = logit_softcap * tanh(s0 / logit_softcap);
|
||||
s1 = logit_softcap * tanh(s1 / logit_softcap);
|
||||
s2 = logit_softcap * tanh(s2 / logit_softcap);
|
||||
s3 = logit_softcap * tanh(s3 / logit_softcap);
|
||||
}
|
||||
|
||||
const ACC_TYPE m_new = max(m_i, max(max(s0, s1), max(s2, s3)));
|
||||
// Whole tile masked (m_new == FA_M_INIT): force the exp() args
|
||||
// far negative so the tile contributes 0, not exp(0)=1.
|
||||
const ACC_TYPE m_exp = (m_new == FA_M_INIT) ? 0.0f : m_new;
|
||||
const ACC_TYPE scale_prev = native_exp(m_i - m_exp);
|
||||
const ACC_TYPE p0 = native_exp(s0 - m_exp);
|
||||
const ACC_TYPE p1 = native_exp(s1 - m_exp);
|
||||
const ACC_TYPE p2 = native_exp(s2 - m_exp);
|
||||
const ACC_TYPE p3 = native_exp(s3 - m_exp);
|
||||
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = mad(p3, CONVERT_KV_ACC4(l_v[j+3][i]),
|
||||
mad(p2, CONVERT_KV_ACC4(l_v[j+2][i]),
|
||||
mad(p1, CONVERT_KV_ACC4(l_v[j+1][i]),
|
||||
mad(p0, CONVERT_KV_ACC4(l_v[j][i]),
|
||||
o_acc[i] * scale_prev))));
|
||||
}
|
||||
l_i = l_i * scale_prev + p0 + p1 + p2 + p3;
|
||||
m_i = m_new;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// End of tile: every thread must finish reading l_k/l_v before the
|
||||
// next iteration's load overwrites them (WAR hazard on local memory).
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
if (my_query_row < n_q) {
|
||||
// Write output.
|
||||
#if N_SPLIT > 1 && defined(HAS_SUBGROUP_SHUFFLE)
|
||||
if (query_valid) {
|
||||
ACC_TYPE sinks_sp = 1.0f;
|
||||
if (sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
sinks_sp = exp(m_i - m_final);
|
||||
l_i = l_i * sinks_sp + exp(m_sink - m_final);
|
||||
m_i = m_final;
|
||||
}
|
||||
const ACC_TYPE l_inv = (l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
||||
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_inv > 0.0f) {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
|
||||
}
|
||||
} else {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif N_SPLIT > 1
|
||||
if (split_idx == 0) {
|
||||
ACC_TYPE sinks_sp = 1.0f;
|
||||
if (query_valid && sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
sinks_sp = exp(m_i - m_final);
|
||||
l_i = l_i * sinks_sp + exp(m_sink - m_final);
|
||||
m_i = m_final;
|
||||
}
|
||||
local_softmax_scale[q_lane] = sinks_sp;
|
||||
local_l_inv[q_lane] = (query_valid && l_i > 0.0f) ? (1.0f / l_i) : 0.0f;
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
if (query_valid) {
|
||||
const ACC_TYPE sinks_sp = local_softmax_scale[q_lane];
|
||||
const ACC_TYPE l_inv = local_l_inv[q_lane];
|
||||
const int dv_off = split_idx * SPLIT_DV_VEC;
|
||||
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
|
||||
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_inv > 0.0f) {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = CONVERT_O_DATA4(o_acc[i] * sinks_sp * l_inv);
|
||||
}
|
||||
} else {
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < SPLIT_DV_VEC; ++i) {
|
||||
o_row[dv_off + i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (query_valid) {
|
||||
if (sinks_void != NULL) {
|
||||
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
const ACC_TYPE m_sink = sinks_ptr[head_idx];
|
||||
const ACC_TYPE m_final = max(m_i, m_sink);
|
||||
|
||||
const ACC_TYPE scale_o = exp(m_i - m_final);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] *= scale_o;
|
||||
}
|
||||
@@ -194,17 +543,18 @@ __kernel void flash_attn_f32_f16(
|
||||
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
|
||||
if (l_i > 0.0f) {
|
||||
const ACC_TYPE l_inv = 1.0f / l_i;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_row[i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__kernel void flash_attn_f32_f16_q1(
|
||||
@@ -258,13 +608,16 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
|
||||
}
|
||||
|
||||
ACC_TYPE4 q_priv[DK_VEC];
|
||||
// Q is uniform across WG threads (n_q=1). Share via local memory to
|
||||
// avoid per-thread q_priv[DK_VEC] dynamic-indexed private array that
|
||||
// spills to DDR on Adreno.
|
||||
__local ACC_TYPE4 q_shared[DK_VEC];
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
|
||||
const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DK_VEC; ++i) {
|
||||
q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
||||
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
|
||||
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
|
||||
@@ -273,14 +626,14 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
|
||||
}
|
||||
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
|
||||
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : FA_M_INIT;
|
||||
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||
if (mask_base != NULL) {
|
||||
@@ -296,7 +649,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
||||
local_m[tid] = m_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -304,7 +657,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
const ACC_TYPE m_final = local_m[0];
|
||||
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
@@ -314,9 +667,9 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
|
||||
const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int k = 0; k < DK_VEC; k++) {
|
||||
dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||
if (mask_base != NULL) {
|
||||
@@ -328,7 +681,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
}
|
||||
const ACC_TYPE p = exp(score - m_final);
|
||||
l_i += p;
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
|
||||
}
|
||||
@@ -338,7 +691,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
__local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
|
||||
local_l[tid] = l_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_l[tid] += local_l[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -357,7 +710,7 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
for (int i = 0; i < DV_VEC; i++) {
|
||||
local_o_comp[tid] = o_acc[i];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
@@ -367,7 +720,257 @@ __kernel void flash_attn_f32_f16_q1(
|
||||
}
|
||||
}
|
||||
} else if (tid == 0) {
|
||||
#pragma unroll
|
||||
FA_UNROLL
|
||||
for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
// Flash-decoding split pass. gid(2) = q_idx * n_splits + split_idx.
|
||||
// Partial record per split: [m, l, O[DV]]. Merge kernel applies sink + norm.
|
||||
#define FA_PARTIAL_FLOATS (2 + DV)
|
||||
|
||||
__kernel void flash_attn_f32_f16_q1_split(
|
||||
const global void * q_void, ulong q_offset,
|
||||
const global void * k_void, ulong k_offset,
|
||||
const global void * v_void, ulong v_offset,
|
||||
const float scale,
|
||||
const int n_q,
|
||||
const int n_kv,
|
||||
const int n_head,
|
||||
const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
|
||||
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
||||
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
|
||||
const float max_bias,
|
||||
const float m0,
|
||||
const float m1,
|
||||
const int n_head_log2,
|
||||
const float logit_softcap,
|
||||
const int n_head_kv,
|
||||
const global void * mask_void,
|
||||
const ulong mask_offset,
|
||||
const ulong mask_nb1,
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3,
|
||||
global float * partial_void,
|
||||
const int n_splits,
|
||||
const int kv_per_split
|
||||
) {
|
||||
const int tid = get_local_id(0);
|
||||
const int head_batch_idx = get_global_id(1);
|
||||
const int split_q_idx = get_global_id(2);
|
||||
const int split_idx = split_q_idx % n_splits;
|
||||
const int q_idx = split_q_idx / n_splits;
|
||||
const int batch_idx = head_batch_idx / n_head;
|
||||
const int head_idx = head_batch_idx % n_head;
|
||||
const int gqa_ratio = n_head / n_head_kv;
|
||||
const int head_kv_idx = head_idx / gqa_ratio;
|
||||
|
||||
const int kv_start = split_idx * kv_per_split;
|
||||
const int kv_end = min(kv_start + kv_per_split, n_kv);
|
||||
|
||||
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
|
||||
const ulong record_idx = ((((ulong) batch_idx * n_head + head_idx) * n_q + q_idx)
|
||||
* n_splits + split_idx);
|
||||
global float * rec = partial_void + record_idx * record_stride;
|
||||
global float4 * rec_o = (global float4 *) (rec + 2);
|
||||
|
||||
if (kv_start >= kv_end) {
|
||||
// Empty split: leave sentinel partial for merge.
|
||||
if (tid == 0) {
|
||||
rec[0] = FA_M_INIT;
|
||||
rec[1] = 0.0f;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const global char * q_base = (const global char *) q_void + q_offset;
|
||||
const global char * k_base = (const global char *) k_void + k_offset;
|
||||
const global char * v_base = (const global char *) v_void + v_offset;
|
||||
|
||||
const global char * mask_base = NULL;
|
||||
if (mask_void != NULL) {
|
||||
const int mask_head_idx = head_idx % mask_ne2;
|
||||
const int mask_batch_idx = batch_idx % mask_ne3;
|
||||
mask_base = (const global char *) mask_void + mask_offset +
|
||||
mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2 +
|
||||
(ulong) q_idx * mask_nb1;
|
||||
}
|
||||
|
||||
// Share Q via local memory (n_q=1 per split -> uniform across WG).
|
||||
__local ACC_TYPE4 q_shared[DK_VEC];
|
||||
const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + (ulong) q_idx * q_nb1;
|
||||
const global Q_DATA_TYPE4 * q_ptr = (const global Q_DATA_TYPE4 *) (q_base + q_row_offset);
|
||||
for (int i = tid; i < DK_VEC; i += Q1_WG_SIZE) {
|
||||
q_shared[i] = CONVERT_Q_ACC4(q_ptr[i]);
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
|
||||
|
||||
// Pass 1a — split-local max.
|
||||
ACC_TYPE m_i = FA_M_INIT;
|
||||
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < DK_VEC; ++k) {
|
||||
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||
if (mask_base != NULL) {
|
||||
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
|
||||
score += slope * (ACC_TYPE) mask_ptr[k_idx];
|
||||
}
|
||||
if (logit_softcap > 0.0f) {
|
||||
score = logit_softcap * tanh(score / logit_softcap);
|
||||
}
|
||||
m_i = max(m_i, score);
|
||||
}
|
||||
|
||||
__local ACC_TYPE local_m[Q1_WG_SIZE];
|
||||
local_m[tid] = m_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
const ACC_TYPE m_c = local_m[0];
|
||||
|
||||
// Pass 1b — softmax-weighted V accumulate.
|
||||
ACC_TYPE4 o_acc[DV_VEC];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
|
||||
ACC_TYPE l_i = 0.0f;
|
||||
|
||||
for (int k_idx = kv_start + tid; k_idx < kv_end; k_idx += Q1_WG_SIZE) {
|
||||
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
|
||||
const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
|
||||
const global KV_DATA_TYPE4 * k_ptr = (const global KV_DATA_TYPE4 *) (k_base + k_row_offset);
|
||||
const global KV_DATA_TYPE4 * v_ptr = (const global KV_DATA_TYPE4 *) (v_base + v_row_offset);
|
||||
ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < DK_VEC; ++k) {
|
||||
dot_acc = mad(q_shared[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
|
||||
}
|
||||
ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
|
||||
if (mask_base != NULL) {
|
||||
const global MASK_DATA_TYPE * mask_ptr = (const global MASK_DATA_TYPE *) (mask_base);
|
||||
score += slope * (ACC_TYPE) mask_ptr[k_idx];
|
||||
}
|
||||
if (logit_softcap > 0.0f) {
|
||||
score = logit_softcap * tanh(score / logit_softcap);
|
||||
}
|
||||
const ACC_TYPE p = exp(score - m_c);
|
||||
l_i += p;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
|
||||
}
|
||||
}
|
||||
|
||||
__local ACC_TYPE local_l[Q1_WG_SIZE];
|
||||
__local ACC_TYPE4 local_o[Q1_WG_SIZE];
|
||||
local_l[tid] = l_i;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_l[tid] += local_l[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
const ACC_TYPE l_c = local_l[0];
|
||||
|
||||
if (tid == 0) {
|
||||
rec[0] = (float) m_c;
|
||||
rec[1] = (float) l_c;
|
||||
}
|
||||
for (int i = 0; i < DV_VEC; ++i) {
|
||||
local_o[tid] = o_acc[i];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
#pragma unroll
|
||||
for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) local_o[tid] += local_o[tid + s];
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
if (tid == 0) {
|
||||
rec_o[i] = local_o[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FD Pass 2: merge per-split partials into final O. Empty splits drop via exp(-INF)=0.
|
||||
__kernel void flash_attn_f32_merge(
|
||||
const global float * partial_void,
|
||||
global void * o_void,
|
||||
const ulong o_offset,
|
||||
const int n_head,
|
||||
const int n_splits,
|
||||
const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
|
||||
const global void * sinks_void,
|
||||
const ulong sinks_offset,
|
||||
const int n_q
|
||||
) {
|
||||
const int lane = get_local_id(0); // 0..DV_VEC-1
|
||||
const int head_batch_idx = get_global_id(1);
|
||||
const int q_idx = get_global_id(2);
|
||||
const int batch_idx = head_batch_idx / n_head;
|
||||
const int head_idx = head_batch_idx % n_head;
|
||||
|
||||
const ulong record_stride = (ulong) FA_PARTIAL_FLOATS;
|
||||
const ulong record_idx_0 = (((ulong) batch_idx * n_head + head_idx) * n_q + q_idx) * n_splits;
|
||||
const global float * rec0 = partial_void + record_idx_0 * record_stride;
|
||||
|
||||
__local ACC_TYPE m_final_shared;
|
||||
__local ACC_TYPE l_final_shared;
|
||||
if (lane == 0) {
|
||||
ACC_TYPE m = FA_M_INIT;
|
||||
for (int c = 0; c < n_splits; ++c) {
|
||||
const ACC_TYPE m_c = rec0[c * record_stride + 0];
|
||||
m = max(m, m_c);
|
||||
}
|
||||
ACC_TYPE m_sink = 0.0f;
|
||||
bool has_sink = false;
|
||||
if (sinks_void != NULL) {
|
||||
const global ACC_TYPE * sinks_ptr =
|
||||
(const global ACC_TYPE *) ((const global char *) sinks_void + sinks_offset);
|
||||
m_sink = sinks_ptr[head_idx];
|
||||
has_sink = true;
|
||||
m = max(m, m_sink);
|
||||
}
|
||||
ACC_TYPE l = 0.0f;
|
||||
for (int c = 0; c < n_splits; ++c) {
|
||||
const ACC_TYPE m_c = rec0[c * record_stride + 0];
|
||||
const ACC_TYPE l_c = rec0[c * record_stride + 1];
|
||||
if (m_c > FA_M_INIT) {
|
||||
l += l_c * exp(m_c - m);
|
||||
}
|
||||
}
|
||||
if (has_sink) {
|
||||
l += exp(m_sink - m);
|
||||
}
|
||||
m_final_shared = m;
|
||||
l_final_shared = l;
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
const ACC_TYPE m_final = m_final_shared;
|
||||
const ACC_TYPE l_final = l_final_shared;
|
||||
const ACC_TYPE l_inv = (l_final > 0.0f) ? (1.0f / l_final) : 0.0f;
|
||||
|
||||
ACC_TYPE4 o = (ACC_TYPE4)(0.0f);
|
||||
for (int c = 0; c < n_splits; ++c) {
|
||||
const global float * rec_c = rec0 + c * record_stride;
|
||||
const ACC_TYPE m_c = rec_c[0];
|
||||
if (m_c <= FA_M_INIT) continue;
|
||||
const global float4 * rec_oc = (const global float4 *) (rec_c + 2);
|
||||
const ACC_TYPE scale_c = exp(m_c - m_final);
|
||||
o = mad((ACC_TYPE4)(scale_c), rec_oc[lane], o);
|
||||
}
|
||||
o = o * l_inv;
|
||||
|
||||
const ulong o_row_offset = (ulong) batch_idx * o_nb3 + (ulong) q_idx * o_nb2 + (ulong) head_idx * o_nb1;
|
||||
global O_DATA_TYPE4 * o_row = (global O_DATA_TYPE4 *) ((global char *) o_void + o_offset + o_row_offset);
|
||||
o_row[lane] = CONVERT_O_DATA4(o);
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,156 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
__kernel void flash_attn_kv_pad_f16(
|
||||
const global void * k_void, ulong k_offset,
|
||||
const global void * v_void, ulong v_offset,
|
||||
global void * k_pad_void,
|
||||
global void * v_pad_void,
|
||||
const int n_kv,
|
||||
const int n_head_kv,
|
||||
const int n_batch,
|
||||
const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
|
||||
const ulong v_nb1, const ulong v_nb2, const ulong v_nb3
|
||||
) {
|
||||
const int row_idx = get_global_id(0);
|
||||
const int head_kv_idx = get_global_id(1);
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
if (row_idx >= BLOCK_N || head_kv_idx >= n_head_kv || batch_idx >= n_batch) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int tail_start = n_kv - (n_kv % BLOCK_N);
|
||||
const int src_row_idx = tail_start + row_idx;
|
||||
|
||||
const global char * k_src = (const global char *) k_void + k_offset;
|
||||
const global char * v_src = (const global char *) v_void + v_offset;
|
||||
global char * k_pad = (global char *) k_pad_void;
|
||||
global char * v_pad = (global char *) v_pad_void;
|
||||
|
||||
const ulong k_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * k_nb1) + (ulong) row_idx * k_nb1;
|
||||
const ulong v_dst_offset = ((ulong) batch_idx * (ulong) n_head_kv + (ulong) head_kv_idx) * ((ulong) BLOCK_N * v_nb1) + (ulong) row_idx * v_nb1;
|
||||
|
||||
if (src_row_idx < n_kv) {
|
||||
const ulong k_src_offset = (ulong) batch_idx * k_nb3 + (ulong) head_kv_idx * k_nb2 + (ulong) src_row_idx * k_nb1;
|
||||
const ulong v_src_offset = (ulong) batch_idx * v_nb3 + (ulong) head_kv_idx * v_nb2 + (ulong) src_row_idx * v_nb1;
|
||||
|
||||
for (ulong i = 0; i < k_nb1; ++i) {
|
||||
k_pad[k_dst_offset + i] = k_src[k_src_offset + i];
|
||||
}
|
||||
for (ulong i = 0; i < v_nb1; ++i) {
|
||||
v_pad[v_dst_offset + i] = v_src[v_src_offset + i];
|
||||
}
|
||||
} else {
|
||||
for (ulong i = 0; i < k_nb1; ++i) {
|
||||
k_pad[k_dst_offset + i] = 0;
|
||||
}
|
||||
for (ulong i = 0; i < v_nb1; ++i) {
|
||||
v_pad[v_dst_offset + i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__kernel void flash_attn_mask_pad_f16(
|
||||
const global void * mask_void, ulong mask_offset,
|
||||
global void * mask_pad_void,
|
||||
const int n_q,
|
||||
const int n_kv,
|
||||
const ulong mask_nb1,
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3
|
||||
) {
|
||||
const int col_idx = get_global_id(0);
|
||||
const int q_row = get_global_id(1);
|
||||
const int mask_slice = get_global_id(2);
|
||||
|
||||
if (col_idx >= BLOCK_N || q_row >= n_q || mask_slice >= mask_ne2 * mask_ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int tail_start = n_kv - (n_kv % BLOCK_N);
|
||||
const int src_col_idx = tail_start + col_idx;
|
||||
const int mask_head_idx = mask_slice % mask_ne2;
|
||||
const int mask_batch_idx = mask_slice / mask_ne2;
|
||||
|
||||
const global char * mask_src_base = (const global char *) mask_void + mask_offset +
|
||||
(ulong) mask_batch_idx * mask_nb3 +
|
||||
(ulong) mask_head_idx * mask_nb2 +
|
||||
(ulong) q_row * mask_nb1;
|
||||
const global half * mask_src = (const global half *) mask_src_base;
|
||||
|
||||
global half * mask_pad = (global half *) mask_pad_void;
|
||||
const ulong dst_idx =
|
||||
(((ulong) mask_batch_idx * (ulong) mask_ne2 + (ulong) mask_head_idx) * (ulong) n_q + (ulong) q_row) * (ulong) BLOCK_N +
|
||||
(ulong) col_idx;
|
||||
|
||||
mask_pad[dst_idx] = src_col_idx < n_kv ? mask_src[src_col_idx] : (half) (-INFINITY);
|
||||
}
|
||||
|
||||
// Per-KV-tile mask class. 0=all -inf (skip tile), 1=mixed (apply mask),
|
||||
// 2=all zero, no -inf (skip mask lookup). Causal diagonal tiles are class 1.
|
||||
__kernel void flash_attn_blk_f16(
|
||||
const global void * mask_void, ulong mask_offset,
|
||||
global char * blk,
|
||||
const int n_q,
|
||||
const int n_kv,
|
||||
const ulong mask_nb1,
|
||||
const ulong mask_nb2,
|
||||
const ulong mask_nb3,
|
||||
const int mask_ne2,
|
||||
const int mask_ne3
|
||||
) {
|
||||
const int kv_block_idx = get_global_id(0);
|
||||
const int q_block_idx = get_global_id(1);
|
||||
const int mask_slice = get_global_id(2);
|
||||
|
||||
const int n_q_blocks = (n_q + BLOCK_M - 1) / BLOCK_M;
|
||||
const int n_kv_blocks = (n_kv + BLOCK_N - 1) / BLOCK_N;
|
||||
if (kv_block_idx >= n_kv_blocks || q_block_idx >= n_q_blocks || mask_slice >= mask_ne2 * mask_ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int mask_head_idx = mask_slice % mask_ne2;
|
||||
const int mask_batch_idx = mask_slice / mask_ne2;
|
||||
const int q_start = q_block_idx * BLOCK_M;
|
||||
const int k_start = kv_block_idx * BLOCK_N;
|
||||
const int q_count = min(BLOCK_M, n_q - q_start);
|
||||
const int k_count = min(BLOCK_N, n_kv - k_start);
|
||||
|
||||
const half neg_max_half = (half) (-65504.0f);
|
||||
char has_unmasked = 0;
|
||||
char has_masked = 0;
|
||||
char has_nonzero = 0;
|
||||
|
||||
const global char * mask_base = (const global char *) mask_void + mask_offset +
|
||||
(ulong) mask_batch_idx * mask_nb3 +
|
||||
(ulong) mask_head_idx * mask_nb2;
|
||||
|
||||
for (int qi = 0; qi < q_count; ++qi) {
|
||||
const global half * mask_row = (const global half *) (mask_base + (ulong) (q_start + qi) * mask_nb1) + k_start;
|
||||
for (int ki = 0; ki < k_count; ++ki) {
|
||||
const half v = mask_row[ki];
|
||||
if (v <= neg_max_half) {
|
||||
has_masked = 1;
|
||||
} else {
|
||||
has_unmasked = 1;
|
||||
if (v != (half) 0.0f) {
|
||||
has_nonzero = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (has_masked && has_unmasked) break; // mixed tile — short-circuit.
|
||||
}
|
||||
|
||||
char res;
|
||||
if (has_unmasked == 0) {
|
||||
res = 0;
|
||||
} else if (has_masked || has_nonzero) {
|
||||
res = 1;
|
||||
} else {
|
||||
res = 2;
|
||||
}
|
||||
|
||||
blk[((ulong) mask_slice * (ulong) n_q_blocks + (ulong) q_block_idx) * (ulong) n_kv_blocks + (ulong) kv_block_idx] = res;
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
|
||||
#ifdef cl_qcom_reqd_sub_group_size
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
// each work-item computes a 4 (rows of A / m) x 8 (cols of B / n) output tile.
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_128
|
||||
#endif
|
||||
kernel void kernel_gemm_noshuffle_q1_0_f32(
|
||||
global const uint * src0_q,
|
||||
global const half * src0_d,
|
||||
read_only image1d_buffer_t src1,
|
||||
global float * dst,
|
||||
int k,
|
||||
int m,
|
||||
int n,
|
||||
int n_no_padding,
|
||||
ulong offsetd
|
||||
) {
|
||||
int n_4 = n >> 2;
|
||||
|
||||
int gy = get_global_id(0);
|
||||
int gx = get_global_id(1);
|
||||
int gx_2 = gx << 2;
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
|
||||
half8 B;
|
||||
|
||||
global const uint* wptr = src0_q + gx_2;
|
||||
global const half* sptr = src0_d + gx_2;
|
||||
|
||||
// 32 weights per uint32, 128 weights (one block / one scale) per 4 uint32.
|
||||
for (int i = 0; i < k; i += 32) {
|
||||
uint4 pack4 = vload4(0, wptr + (i / 32) * m); // 4 rows, 32 K-values each
|
||||
half4 scale = vload4(0, sptr + (i / 128) * m); // 4 rows, one scale per 128
|
||||
|
||||
for (int j = 0; j < 32; ++j) {
|
||||
B.s0123 = read_imageh(src1, gy * 2 + (i + j) * n_4);
|
||||
B.s4567 = read_imageh(src1, gy * 2 + (i + j) * n_4 + 1);
|
||||
|
||||
// sign bit -> +-1 (half arithmetic avoids unsigned underflow)
|
||||
half4 wj = (half4)(
|
||||
2.0h * (half)((pack4.s0 >> j) & 1u) - 1.0h,
|
||||
2.0h * (half)((pack4.s1 >> j) & 1u) - 1.0h,
|
||||
2.0h * (half)((pack4.s2 >> j) & 1u) - 1.0h,
|
||||
2.0h * (half)((pack4.s3 >> j) & 1u) - 1.0h) * scale;
|
||||
|
||||
c0 += B * wj.s0;
|
||||
c1 += B * wj.s1;
|
||||
c2 += B * wj.s2;
|
||||
c3 += B * wj.s3;
|
||||
}
|
||||
}
|
||||
|
||||
int idx = (gy << 3) * m + (gx << 2);
|
||||
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
|
||||
idx += m;
|
||||
}
|
||||
if(idx+3 < m*n_no_padding){
|
||||
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
|
||||
#ifdef cl_qcom_reqd_sub_group_size
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#endif
|
||||
|
||||
#define QK1_0 128
|
||||
#define N_SIMDGROUP 4
|
||||
|
||||
#define dequantizeBlockAccum_q1(total, bits, scale, regB, lb) \
|
||||
total += (2.0f*(float)((bits >> 0) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+0); \
|
||||
total += (2.0f*(float)((bits >> 1) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+0); \
|
||||
total += (2.0f*(float)((bits >> 2) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+0); \
|
||||
total += (2.0f*(float)((bits >> 3) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+0); \
|
||||
total += (2.0f*(float)((bits >> 4) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+0); \
|
||||
total += (2.0f*(float)((bits >> 5) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+0); \
|
||||
total += (2.0f*(float)((bits >> 6) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+0); \
|
||||
total += (2.0f*(float)((bits >> 7) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+0); \
|
||||
total += (2.0f*(float)((bits >> 8) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+1); \
|
||||
total += (2.0f*(float)((bits >> 9) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+1); \
|
||||
total += (2.0f*(float)((bits >> 10) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+1); \
|
||||
total += (2.0f*(float)((bits >> 11) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+1); \
|
||||
total += (2.0f*(float)((bits >> 12) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+1); \
|
||||
total += (2.0f*(float)((bits >> 13) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+1); \
|
||||
total += (2.0f*(float)((bits >> 14) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+1); \
|
||||
total += (2.0f*(float)((bits >> 15) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+1); \
|
||||
total += (2.0f*(float)((bits >> 16) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+2); \
|
||||
total += (2.0f*(float)((bits >> 17) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+2); \
|
||||
total += (2.0f*(float)((bits >> 18) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+2); \
|
||||
total += (2.0f*(float)((bits >> 19) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+2); \
|
||||
total += (2.0f*(float)((bits >> 20) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+2); \
|
||||
total += (2.0f*(float)((bits >> 21) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+2); \
|
||||
total += (2.0f*(float)((bits >> 22) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+2); \
|
||||
total += (2.0f*(float)((bits >> 23) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+2); \
|
||||
total += (2.0f*(float)((bits >> 24) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s0, lb+3); \
|
||||
total += (2.0f*(float)((bits >> 25) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s1, lb+3); \
|
||||
total += (2.0f*(float)((bits >> 26) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s2, lb+3); \
|
||||
total += (2.0f*(float)((bits >> 27) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s3, lb+3); \
|
||||
total += (2.0f*(float)((bits >> 28) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s4, lb+3); \
|
||||
total += (2.0f*(float)((bits >> 29) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s5, lb+3); \
|
||||
total += (2.0f*(float)((bits >> 30) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s6, lb+3); \
|
||||
total += (2.0f*(float)((bits >> 31) & 1u) - 1.0f) * scale * sub_group_broadcast(regB.s7, lb+3);
|
||||
|
||||
|
||||
#ifdef ADRENO_GPU
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
__kernel void kernel_gemv_noshuffle_q1_0_f32(
|
||||
read_only image1d_buffer_t src0_q,
|
||||
global half * src0_d,
|
||||
read_only image1d_buffer_t src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne10,
|
||||
int ne12,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3)
|
||||
{
|
||||
uint groupId = get_local_id(1);
|
||||
uint gid = get_global_id(0);
|
||||
ushort slid = get_sub_group_local_id();
|
||||
|
||||
uint K = ne00;
|
||||
uint M = ne01;
|
||||
|
||||
uint LINE_STRIDE_A = M;
|
||||
uint BLOCK_STRIDE_A = 4 * M;
|
||||
|
||||
uint4 regA;
|
||||
half regS;
|
||||
float8 regB;
|
||||
|
||||
float totalSum = 0.0f;
|
||||
|
||||
#pragma unroll 1
|
||||
for (uint kb = groupId; kb < (K / QK1_0); kb += N_SIMDGROUP) {
|
||||
regS = src0_d[gid + kb * LINE_STRIDE_A]; // each fiber loads its row's scale
|
||||
|
||||
// first 16 fibers load 8 B values each -> 128 activations for this block
|
||||
if (slid < 16) {
|
||||
regB.s0123 = read_imagef(src1, (slid * 2 + kb * 32));
|
||||
regB.s4567 = read_imagef(src1, (1 + slid * 2 + kb * 32));
|
||||
}
|
||||
|
||||
// load this row's 4 uint32 (128 sign bits)
|
||||
regA.s0 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
|
||||
regA.s1 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
|
||||
regA.s2 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
|
||||
regA.s3 = read_imageui(src0_q, (gid + kb * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
|
||||
|
||||
float scale = (float)regS;
|
||||
dequantizeBlockAccum_q1(totalSum, regA.s0, scale, regB, 0);
|
||||
dequantizeBlockAccum_q1(totalSum, regA.s1, scale, regB, 4);
|
||||
dequantizeBlockAccum_q1(totalSum, regA.s2, scale, regB, 8);
|
||||
dequantizeBlockAccum_q1(totalSum, regA.s3, scale, regB, 12);
|
||||
}
|
||||
|
||||
// reduction in local memory, assumes #wave = N_SIMDGROUP = 4
|
||||
local float reduceLM[SIMDGROUP_WIDTH * 3];
|
||||
if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
|
||||
if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
|
||||
if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
|
||||
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
|
||||
if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
|
||||
|
||||
if (groupId == 0) {
|
||||
dst = (global float*)((global char*)dst + offsetd);
|
||||
dst[gid] = totalSum;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
// LOAD_VEC_A is 8 because one q1_0 quant byte expands to 8 weights along K.
|
||||
#define LOAD_VEC_A 8
|
||||
#define LOAD_VEC_B 4
|
||||
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 32
|
||||
#define TM 4
|
||||
#define TN 8
|
||||
|
||||
kernel void kernel_mul_mm_q1_0_f32_l4_lm(
|
||||
global uchar * src0_q,
|
||||
global half * src0_d,
|
||||
global float4 * src1,
|
||||
ulong offset1,
|
||||
global float * dst,
|
||||
ulong offsetd,
|
||||
|
||||
int ne00,
|
||||
int ne01,
|
||||
int ne02,
|
||||
int ne11,
|
||||
int ne12,
|
||||
|
||||
int stride_a,
|
||||
int stride_b,
|
||||
int stride_d,
|
||||
|
||||
int batch_stride_a,
|
||||
int batch_stride_b,
|
||||
int batch_stride_d,
|
||||
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global float4*)((global char*)src1 + offset1);
|
||||
dst = (global float *)((global char*)dst + offsetd);
|
||||
|
||||
local float buf_a[BM * BK];
|
||||
local float buf_b[BN * BK];
|
||||
|
||||
const int batch_idx = get_global_id(2);
|
||||
|
||||
const int i13 = batch_idx / ne12;
|
||||
const int i12 = batch_idx % ne12;
|
||||
|
||||
const int i03 = i13 / r3;
|
||||
const int i02 = i12 / r2;
|
||||
|
||||
const int batch_idx_a = i03 * ne02 + i02;
|
||||
|
||||
const int ir = get_group_id(0);
|
||||
const int ic = get_group_id(1);
|
||||
|
||||
const int tid = get_local_id(0);
|
||||
const int th_r = tid % (BM / TM);
|
||||
const int th_c = tid / (BM / TM);
|
||||
|
||||
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
|
||||
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
|
||||
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
|
||||
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
|
||||
|
||||
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
|
||||
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
|
||||
|
||||
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
|
||||
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
|
||||
|
||||
float sums[TM * TN];
|
||||
float cache_a[TM];
|
||||
float cache_b[TN];
|
||||
|
||||
for (int i = 0; i < TM * TN; i++) {
|
||||
sums[i] = 0.0f;
|
||||
}
|
||||
|
||||
for (int block = 0; block < ne00; block += BK) {
|
||||
for (int l = 0; l < BM; l += loadstride_a) {
|
||||
if (ir*BM + loadc_a + l < ne01) {
|
||||
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
|
||||
int ib = idx / 16; // 16 quant bytes per q1_0 block
|
||||
|
||||
float d = (float)src0_d[ib];
|
||||
uint bits = src0_q[idx];
|
||||
|
||||
// use float to avoid unsigned underflow of (2*0 - 1).
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 0) & 1) - 1.0f);
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 1) & 1) - 1.0f);
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 2) & 1) - 1.0f);
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 3) & 1) - 1.0f);
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 4) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 4) & 1) - 1.0f);
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 5) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 5) & 1) - 1.0f);
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 6) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 6) & 1) - 1.0f);
|
||||
buf_a[(loadr_a * LOAD_VEC_A + 7) * BM + loadc_a + l] = d * (2.0f*(float)((bits >> 7) & 1) - 1.0f);
|
||||
} else {
|
||||
for (int b = 0; b < LOAD_VEC_A; ++b) {
|
||||
buf_a[(loadr_a * LOAD_VEC_A + b) * BM + loadc_a + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int l = 0; l < BN; l += loadstride_b) {
|
||||
if (ic*BN + loadc_b + l < ne11) {
|
||||
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
|
||||
} else {
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
|
||||
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
|
||||
for (int i = 0; i < BK; i++) {
|
||||
for (int j = 0; j < TM; j++) {
|
||||
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
|
||||
}
|
||||
|
||||
for (int j = 0; j < TN; j++) {
|
||||
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
|
||||
}
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
const int sums_idx = cc*TM + cr;
|
||||
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
}
|
||||
|
||||
const int dr = ir * BM + th_r * TM;
|
||||
const int dc = ic * BN + th_c * TN;
|
||||
|
||||
const int offsets = batch_idx * batch_stride_d;
|
||||
|
||||
for (int cc = 0; cc < TN; cc++) {
|
||||
for (int cr = 0; cr < TM; cr++) {
|
||||
if (dr + cr < ne01 && dc + cc < ne11) {
|
||||
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK1_0 128
|
||||
typedef struct {
|
||||
half d;
|
||||
uchar qs[QK1_0/8];
|
||||
} block_q1_0;
|
||||
|
||||
#define NB_Q1_0 16
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_R0_Q1_0 4 // number of rows each subgroup works on
|
||||
#define N_SG_Q1_0 2 // number of subgroups in a work group
|
||||
#define N_SIMDWIDTH 16 // subgroup size
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_R0_Q1_0 4
|
||||
#define N_SG_Q1_0 2
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
inline float block_q_1_0_dot_y(global block_q1_0 * qb, float sumy, float yl[NB_Q1_0], short il) {
|
||||
global uchar * qs = qb->qs + il*2;
|
||||
uint b0 = qs[0];
|
||||
uint b1 = qs[1];
|
||||
|
||||
float acc = 0.f;
|
||||
acc += yl[ 0]*(float)((b0 >> 0) & 1) + yl[ 1]*(float)((b0 >> 1) & 1);
|
||||
acc += yl[ 2]*(float)((b0 >> 2) & 1) + yl[ 3]*(float)((b0 >> 3) & 1);
|
||||
acc += yl[ 4]*(float)((b0 >> 4) & 1) + yl[ 5]*(float)((b0 >> 5) & 1);
|
||||
acc += yl[ 6]*(float)((b0 >> 6) & 1) + yl[ 7]*(float)((b0 >> 7) & 1);
|
||||
|
||||
acc += yl[ 8]*(float)((b1 >> 0) & 1) + yl[ 9]*(float)((b1 >> 1) & 1);
|
||||
acc += yl[10]*(float)((b1 >> 2) & 1) + yl[11]*(float)((b1 >> 3) & 1);
|
||||
acc += yl[12]*(float)((b1 >> 4) & 1) + yl[13]*(float)((b1 >> 5) & 1);
|
||||
acc += yl[14]*(float)((b1 >> 6) & 1) + yl[15]*(float)((b1 >> 7) & 1);
|
||||
|
||||
return qb->d * (2.0f*acc - sumy);
|
||||
}
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q1_0_f32(
|
||||
global char * src0,
|
||||
ulong offset0,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne12,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src0 = (global char*)((global char*)src0 + offset0);
|
||||
src1 = (global char*)((global char*)src1 + offset1);
|
||||
dst = (global char*)((global char*)dst + offsetd);
|
||||
|
||||
int nb = ne00/QK1_0;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int first_row = (r0*N_SG_Q1_0 + get_sub_group_id()) * N_R0_Q1_0;
|
||||
|
||||
uint i12 = im%ne12;
|
||||
uint i13 = im/ne12;
|
||||
|
||||
ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;
|
||||
global float * y = (global float *) (src1 + offset_src1);
|
||||
|
||||
// pointers to src0 rows
|
||||
global block_q1_0 * ax[N_R0_Q1_0];
|
||||
for (int row = 0; row < N_R0_Q1_0; ++row) {
|
||||
ulong offset_src0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
ax[row] = (global block_q1_0 *) ((global char *) src0 + offset_src0);
|
||||
}
|
||||
|
||||
float yl[NB_Q1_0];
|
||||
float sumf[N_R0_Q1_0] = { 0.f };
|
||||
|
||||
const short ix = get_sub_group_local_id()/8;
|
||||
const short il = get_sub_group_local_id()%8;
|
||||
|
||||
global float * yb = y + ix*QK1_0 + il*NB_Q1_0;
|
||||
|
||||
// each thread handles NB_Q1_0 quants at a time
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
|
||||
float sumy = 0.f;
|
||||
for (short i = 0; i < NB_Q1_0; ++i) {
|
||||
yl[i] = yb[i];
|
||||
sumy += yb[i];
|
||||
}
|
||||
|
||||
for (short row = 0; row < N_R0_Q1_0; row++) {
|
||||
sumf[row] += block_q_1_0_dot_y(ax[row] + ib, sumy, yl, il);
|
||||
}
|
||||
|
||||
yb += N_SIMDWIDTH*NB_Q1_0;
|
||||
}
|
||||
|
||||
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
|
||||
|
||||
for (int row = 0; row < N_R0_Q1_0; ++row) {
|
||||
float tot = sub_group_reduce_add(sumf[row]);
|
||||
|
||||
if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
|
||||
dst_f32[first_row + row] = tot;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
|
||||
|
||||
#ifdef cl_intel_subgroups
|
||||
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
|
||||
#else
|
||||
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
|
||||
#endif
|
||||
|
||||
#ifdef cl_intel_required_subgroup_size
|
||||
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
|
||||
#define INTEL_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
|
||||
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
|
||||
#elif defined(cl_qcom_reqd_sub_group_size)
|
||||
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
|
||||
#define ADRENO_GPU 1
|
||||
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
|
||||
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
|
||||
#endif
|
||||
|
||||
#define QK1_0 128
|
||||
#define QK1_0_BYTES (QK1_0/8) // 16 quant bytes per block
|
||||
#define QK1_0_BLK_BYTES (QK1_0_BYTES + 2) // d + qs in original tensor = 18
|
||||
|
||||
#define NB_Q1_0 16 // quants handled per thread (two qs bytes)
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
#define N_R0_Q1_0 4 // number of rows each subgroup works on
|
||||
#define N_SG_Q1_0 2 // number of subgroups in a work group
|
||||
#define N_SIMDWIDTH 16 // subgroup size
|
||||
#elif defined (ADRENO_GPU)
|
||||
#define N_R0_Q1_0 4
|
||||
#define N_SG_Q1_0 2
|
||||
#define N_SIMDWIDTH 64
|
||||
#endif
|
||||
|
||||
#ifdef INTEL_GPU
|
||||
REQD_SUBGROUP_SIZE_16
|
||||
#elif defined (ADRENO_GPU)
|
||||
REQD_SUBGROUP_SIZE_64
|
||||
#endif
|
||||
kernel void kernel_mul_mv_q1_0_f32_flat(
|
||||
global char * src0_q,
|
||||
global half * src0_d,
|
||||
global char * src1,
|
||||
ulong offset1,
|
||||
global char * dst,
|
||||
ulong offsetd,
|
||||
int ne00,
|
||||
int ne01,
|
||||
ulong nb01,
|
||||
ulong nb02,
|
||||
ulong nb03,
|
||||
int ne12,
|
||||
ulong nb11,
|
||||
ulong nb12,
|
||||
ulong nb13,
|
||||
int ne0,
|
||||
int ne1,
|
||||
int r2,
|
||||
int r3
|
||||
) {
|
||||
src1 = (global char*)((global char*)src1 + offset1);
|
||||
dst = (global char*)((global char*)dst + offsetd);
|
||||
|
||||
int nb = ne00/QK1_0;
|
||||
|
||||
int r0 = get_group_id(0);
|
||||
int r1 = get_group_id(1);
|
||||
int im = get_group_id(2);
|
||||
|
||||
int first_row = (r0*N_SG_Q1_0 + get_sub_group_id()) * N_R0_Q1_0;
|
||||
|
||||
uint i12 = im%ne12;
|
||||
uint i13 = im/ne12;
|
||||
|
||||
ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;
|
||||
global float * y = (global float *) (src1 + offset_src1);
|
||||
|
||||
// pointers to src0 rows (flat: q bytes + scales)
|
||||
uint offset_src0_base = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
|
||||
|
||||
global uchar * ax0, * ax1, * ax2, * ax3;
|
||||
global half * ad0, * ad1, * ad2, * ad3;
|
||||
uint offset_src0;
|
||||
|
||||
offset_src0 = (offset_src0_base + 0*nb01) / QK1_0_BLK_BYTES;
|
||||
ax0 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
|
||||
ad0 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
|
||||
|
||||
offset_src0 = (offset_src0_base + 1*nb01) / QK1_0_BLK_BYTES;
|
||||
ax1 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
|
||||
ad1 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
|
||||
|
||||
offset_src0 = (offset_src0_base + 2*nb01) / QK1_0_BLK_BYTES;
|
||||
ax2 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
|
||||
ad2 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
|
||||
|
||||
offset_src0 = (offset_src0_base + 3*nb01) / QK1_0_BLK_BYTES;
|
||||
ax3 = (global uchar *) ((global char *) src0_q + offset_src0*QK1_0_BYTES);
|
||||
ad3 = (global half *) ((global char *) src0_d + offset_src0*sizeof(half));
|
||||
|
||||
const short ix = get_sub_group_local_id()/8;
|
||||
const short il = get_sub_group_local_id()%8;
|
||||
|
||||
global float * yb = y + ix*QK1_0 + il*NB_Q1_0;
|
||||
|
||||
float8 yl_lo;
|
||||
float8 yl_hi;
|
||||
float4 sumf = 0.f;
|
||||
|
||||
// each thread handles NB_Q1_0 = 16 quants (two qs bytes) at a time
|
||||
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
|
||||
yl_lo = vload8(0, yb);
|
||||
yl_hi = vload8(0, yb + 8);
|
||||
float sumy = yl_lo.s0 + yl_lo.s1 + yl_lo.s2 + yl_lo.s3
|
||||
+ yl_lo.s4 + yl_lo.s5 + yl_lo.s6 + yl_lo.s7
|
||||
+ yl_hi.s0 + yl_hi.s1 + yl_hi.s2 + yl_hi.s3
|
||||
+ yl_hi.s4 + yl_hi.s5 + yl_hi.s6 + yl_hi.s7;
|
||||
|
||||
uint b0, b1;
|
||||
float acc;
|
||||
|
||||
b0 = ax0[ib*QK1_0_BYTES + il*2 + 0];
|
||||
b1 = ax0[ib*QK1_0_BYTES + il*2 + 1];
|
||||
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
|
||||
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
|
||||
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
|
||||
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
|
||||
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
|
||||
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
|
||||
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
|
||||
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
|
||||
sumf.s0 += (float)ad0[ib] * (2.0f*acc - sumy);
|
||||
|
||||
b0 = ax1[ib*QK1_0_BYTES + il*2 + 0];
|
||||
b1 = ax1[ib*QK1_0_BYTES + il*2 + 1];
|
||||
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
|
||||
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
|
||||
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
|
||||
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
|
||||
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
|
||||
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
|
||||
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
|
||||
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
|
||||
sumf.s1 += (float)ad1[ib] * (2.0f*acc - sumy);
|
||||
|
||||
b0 = ax2[ib*QK1_0_BYTES + il*2 + 0];
|
||||
b1 = ax2[ib*QK1_0_BYTES + il*2 + 1];
|
||||
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
|
||||
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
|
||||
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
|
||||
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
|
||||
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
|
||||
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
|
||||
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
|
||||
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
|
||||
sumf.s2 += (float)ad2[ib] * (2.0f*acc - sumy);
|
||||
|
||||
b0 = ax3[ib*QK1_0_BYTES + il*2 + 0];
|
||||
b1 = ax3[ib*QK1_0_BYTES + il*2 + 1];
|
||||
acc = yl_lo.s0*(float)((b0 >> 0) & 1) + yl_lo.s1*(float)((b0 >> 1) & 1)
|
||||
+ yl_lo.s2*(float)((b0 >> 2) & 1) + yl_lo.s3*(float)((b0 >> 3) & 1)
|
||||
+ yl_lo.s4*(float)((b0 >> 4) & 1) + yl_lo.s5*(float)((b0 >> 5) & 1)
|
||||
+ yl_lo.s6*(float)((b0 >> 6) & 1) + yl_lo.s7*(float)((b0 >> 7) & 1)
|
||||
+ yl_hi.s0*(float)((b1 >> 0) & 1) + yl_hi.s1*(float)((b1 >> 1) & 1)
|
||||
+ yl_hi.s2*(float)((b1 >> 2) & 1) + yl_hi.s3*(float)((b1 >> 3) & 1)
|
||||
+ yl_hi.s4*(float)((b1 >> 4) & 1) + yl_hi.s5*(float)((b1 >> 5) & 1)
|
||||
+ yl_hi.s6*(float)((b1 >> 6) & 1) + yl_hi.s7*(float)((b1 >> 7) & 1);
|
||||
sumf.s3 += (float)ad3[ib] * (2.0f*acc - sumy);
|
||||
|
||||
yb += N_SIMDWIDTH*NB_Q1_0;
|
||||
}
|
||||
|
||||
global float * dst_f32 = (global float *) dst + (ulong)im*ne0*ne1 + (ulong)r1*ne0;
|
||||
|
||||
float4 tot = (float4)(
|
||||
sub_group_reduce_add(sumf.s0),
|
||||
sub_group_reduce_add(sumf.s1),
|
||||
sub_group_reduce_add(sumf.s2),
|
||||
sub_group_reduce_add(sumf.s3)
|
||||
);
|
||||
|
||||
if (get_sub_group_local_id() == 0) {
|
||||
if (first_row + 0 < ne01) dst_f32[first_row + 0] = tot.s0;
|
||||
if (first_row + 1 < ne01) dst_f32[first_row + 1] = tot.s1;
|
||||
if (first_row + 2 < ne01) dst_f32[first_row + 2] = tot.s2;
|
||||
if (first_row + 3 < ne01) dst_f32[first_row + 3] = tot.s3;
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user