mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-04-28 05:50:07 +02:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
434b2a1ff6 | ||
|
|
983ca8992e | ||
|
|
665abc6097 | ||
|
|
4414c04b9a | ||
|
|
ceaf47c4b1 | ||
|
|
42401c72b8 |
@@ -856,7 +856,7 @@ void common_memory_breakdown_print(const struct llama_context * ctx) {
|
||||
ggml_backend_dev_memory(dev, &free, &total);
|
||||
|
||||
const size_t self = mb.model + mb.context + mb.compute;
|
||||
const size_t unaccounted = total - self - free;
|
||||
const int64_t unaccounted = static_cast<int64_t>(total) - static_cast<int64_t>(free) - static_cast<int64_t>(self);
|
||||
|
||||
table_data.push_back({
|
||||
template_gpu,
|
||||
@@ -867,7 +867,7 @@ void common_memory_breakdown_print(const struct llama_context * ctx) {
|
||||
std::to_string(mb.model / MiB),
|
||||
std::to_string(mb.context / MiB),
|
||||
std::to_string(mb.compute / MiB),
|
||||
std::to_string(unaccounted / MiB)});
|
||||
std::to_string(unaccounted / static_cast<int64_t>(MiB))});
|
||||
}
|
||||
|
||||
// print memory breakdown for host:
|
||||
|
||||
@@ -1101,7 +1101,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
|
||||
fs::path cache_file = fs::path(cache_dir) / hash_str;
|
||||
std::ofstream ofs(cache_file, std::ios::binary);
|
||||
ofs.write((const char *)data, size);
|
||||
GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str());
|
||||
GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.string().c_str());
|
||||
}
|
||||
ggml_backend_tensor_set(tensor, data, offset, size);
|
||||
return true;
|
||||
|
||||
@@ -1287,6 +1287,7 @@ class ggml_webgpu_shader_lib {
|
||||
std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper);
|
||||
|
||||
switch (key.src_type) {
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q8_0:
|
||||
@@ -1323,7 +1324,9 @@ class ggml_webgpu_shader_lib {
|
||||
|
||||
defines.push_back("DST_TYPE=f32");
|
||||
|
||||
if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
||||
if (key.src_type == GGML_TYPE_Q1_0) {
|
||||
defines.push_back("BLOCK_SIZE=128u");
|
||||
} else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) ||
|
||||
key.src_type == GGML_TYPE_IQ4_NL) {
|
||||
defines.push_back("BLOCK_SIZE=32u");
|
||||
} else if (key.src_type >= GGML_TYPE_Q2_K) {
|
||||
@@ -1615,6 +1618,24 @@ class ggml_webgpu_shader_lib {
|
||||
defines.push_back("MUL_ACC_" + type_upper);
|
||||
defines.push_back("U32_DEQUANT_HELPERS");
|
||||
defines.push_back("SRC0_INNER_TYPE=u32");
|
||||
switch (context.src0->type) {
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
defines.push_back(type_upper + "_GRID");
|
||||
defines.push_back(type_upper + "_TABLES");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -1639,7 +1660,9 @@ class ggml_webgpu_shader_lib {
|
||||
uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE;
|
||||
uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG;
|
||||
|
||||
if (key.src0_type >= GGML_TYPE_Q2_K) {
|
||||
if (key.src0_type == GGML_TYPE_Q1_0) {
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
||||
} else if (key.src0_type >= GGML_TYPE_Q2_K) {
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG;
|
||||
} else if (key.src0_type >= GGML_TYPE_Q4_0) {
|
||||
outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG;
|
||||
|
||||
@@ -1389,8 +1389,20 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx,
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q3_K:
|
||||
case GGML_TYPE_Q2_K:
|
||||
case GGML_TYPE_Q1_0:
|
||||
use_fast = true;
|
||||
break;
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_M:
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_IQ3_XXS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
use_fast = is_vec;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -3725,6 +3737,7 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
|
||||
|
||||
static bool ggml_webgpu_supported_qtype(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -3819,6 +3832,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -3857,6 +3871,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q1_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
|
||||
@@ -27,6 +27,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q1_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_byte_base = (src_base + offset) * 18;
|
||||
let d = load_f16_as_f32_at_src(block_byte_base);
|
||||
for (var j: u32 = 0u; j < 4u; j++) {
|
||||
let q_packed = load_u32_at_src(block_byte_base + 2u + j * 4u);
|
||||
let dst_base128 = dst_base + offset * 128u + j * 32u;
|
||||
for (var k: u32 = 0; k < 4u; k++) {
|
||||
let q_byte = get_byte(q_packed, k);
|
||||
for (var bit: u32 = 0; bit < 8u; bit++) {
|
||||
let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
|
||||
dst[dst_base128 + k * 8u + bit] = w;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef Q4_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes
|
||||
|
||||
@@ -61,6 +61,39 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
|
||||
#endif // INIT_SRC1_SHMEM_FLOAT
|
||||
#endif
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q1_0
|
||||
const BLOCK_SIZE = 128u;
|
||||
const BLOCK_SIZE_BYTES = 18u;
|
||||
const NQ = 8u; // 8 weights (1 byte of qs) per thread per iteration
|
||||
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
|
||||
let tile_m = i / TILE_K;
|
||||
let tile_k_start = i % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k_start = k_outer + tile_k_start;
|
||||
|
||||
if (global_m >= params.m) {
|
||||
break;
|
||||
}
|
||||
|
||||
let block_k = global_k_start / BLOCK_SIZE;
|
||||
let byte_in_block = (global_k_start % BLOCK_SIZE) / 8u;
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + block_k;
|
||||
let block_byte_base = src0_idx * BLOCK_SIZE_BYTES;
|
||||
let d = load_f16_at_src0(block_byte_base);
|
||||
let q_byte = load_u32_at_src0(block_byte_base + 2u + byte_in_block) & 0xFFu;
|
||||
|
||||
for (var bit = 0u; bit < NQ; bit++) {
|
||||
let global_k = global_k_start + bit;
|
||||
if (global_k < params.k) {
|
||||
shmem[i + bit] = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif // INIT_SRC0_SHMEM_Q1_0
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
const BLOCK_SIZE_BYTES = 18u;
|
||||
|
||||
@@ -128,6 +128,38 @@ fn main(
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q1_0
|
||||
#define BLOCK_SIZE 128
|
||||
#define BLOCK_SIZE_BYTES 18
|
||||
#define THREADS_PER_BLOCK 16
|
||||
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
let thread_within_block = thread_id % THREADS_PER_BLOCK;
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD;
|
||||
var x_block: array<f32, ELEMS_PER_THREAD>;
|
||||
for (var i = 0u; i < ELEMS_PER_THREAD; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu;
|
||||
var row_sum = 0.0;
|
||||
for (var bit = 0u; bit < 8u; bit++) {
|
||||
let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u);
|
||||
row_sum += w * x_block[bit];
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_Q4_0
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCK_SIZE_BYTES 18
|
||||
@@ -812,6 +844,520 @@ fn main(
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ1_S
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 50
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu;
|
||||
let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u);
|
||||
let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let qs_byte = get_byte(qs_w, l);
|
||||
let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u;
|
||||
let gw = iq1_grid[ig / 16u];
|
||||
let bit_base = (ig % 16u) * 2u;
|
||||
for (var j = 0u; j < 8u; j++) {
|
||||
let g = (gw >> (bit_base + j * 2u)) & 3u;
|
||||
let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
|
||||
row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ1_M
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 56
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
|
||||
let sc_lo = load_u32_at_src0(block_byte_base + 48u);
|
||||
let sc_hi = load_u32_at_src0(block_byte_base + 52u);
|
||||
let sc0 = sc_lo & 0xFFFFu;
|
||||
let sc1 = (sc_lo >> 16u) & 0xFFFFu;
|
||||
let sc2 = sc_hi & 0xFFFFu;
|
||||
let sc3 = (sc_hi >> 16u) & 0xFFFFu;
|
||||
let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u);
|
||||
let d = f32(bitcast<vec2<f16>>(d_bits)[0]);
|
||||
|
||||
let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u),
|
||||
select(sc0, sc1, sub_blk >= 2u),
|
||||
sub_blk < 4u);
|
||||
|
||||
let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u);
|
||||
let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu;
|
||||
let qh_lo = qh & 0xFFu;
|
||||
let qh_hi = (qh >> 8u) & 0xFFu;
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u);
|
||||
let sub_scale = (sc_u16 >> bit_off) & 0x7u;
|
||||
let dl = d * f32(2u * sub_scale + 1u);
|
||||
let qh_byte = select(qh_lo, qh_hi, l >= 2u);
|
||||
let ll2 = l % 2u;
|
||||
let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u);
|
||||
let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u);
|
||||
let ig = grid_idx * 8u;
|
||||
let gw = iq1_grid[ig / 16u];
|
||||
let bit_base = (ig % 16u) * 2u;
|
||||
for (var j = 0u; j < 8u; j++) {
|
||||
let g = (gw >> (bit_base + j * 2u)) & 3u;
|
||||
let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u);
|
||||
row_sum += dl * (gs + delta) * x_block[ll * 8u + j];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ2_XXS
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 66
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
|
||||
let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
|
||||
let ls = aux_hi >> 28u;
|
||||
let db = d * (0.5 + f32(ls)) * 0.25;
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let grid_idx = (aux_lo >> (8u * l)) & 0xFFu;
|
||||
let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu;
|
||||
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
|
||||
let gw_lo = iq2xxs_grid[grid_idx * 2u];
|
||||
let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u];
|
||||
for (var j = 0u; j < 8u; j++) {
|
||||
let gw = select(gw_hi, gw_lo, j < 4u);
|
||||
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
|
||||
let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
|
||||
row_sum += db * b * s * x_block[ll * 8u + j];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ2_XS
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 74
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
|
||||
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
|
||||
let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
|
||||
let scales_byte = get_byte(scales_word, sub_blk % 4u);
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let qs_word = select(qs_hi, qs_lo, l < 2u);
|
||||
let half2 = (l % 2u) * 16u;
|
||||
let qs_val = (qs_word >> half2) & 0xFFFFu;
|
||||
let grid_idx = qs_val & 0x1FFu;
|
||||
let signs_idx = (qs_val >> 9u) & 0x7Fu;
|
||||
let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu;
|
||||
let db = d * (0.5 + f32(sub_scale)) * 0.25;
|
||||
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
|
||||
let gw_lo = iq2xs_grid[grid_idx * 2u];
|
||||
let gw_hi = iq2xs_grid[grid_idx * 2u + 1u];
|
||||
for (var j = 0u; j < 8u; j++) {
|
||||
let gw = select(gw_hi, gw_lo, j < 4u);
|
||||
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
|
||||
let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
|
||||
row_sum += db * b * s * x_block[ll * 8u + j];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ2_S
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 82
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u);
|
||||
let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u);
|
||||
let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
|
||||
let qh_byte = get_byte(qh_word, sub_blk % 4u);
|
||||
let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u);
|
||||
let scales_byte = get_byte(sc_word, sub_blk % 4u);
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let qs_byte = get_byte(qs_w, l);
|
||||
let sign_byte = get_byte(sg_w, l);
|
||||
let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u);
|
||||
let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu;
|
||||
let db = d * (0.5 + f32(sub_scale)) * 0.25;
|
||||
let gw_lo = iq2s_grid[grid_idx * 2u];
|
||||
let gw_hi = iq2s_grid[grid_idx * 2u + 1u];
|
||||
for (var j = 0u; j < 8u; j++) {
|
||||
let gw = select(gw_hi, gw_lo, j < 4u);
|
||||
let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu);
|
||||
let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u);
|
||||
row_sum += db * b * s * x_block[ll * 8u + j];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ3_XXS
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 98
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
|
||||
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
|
||||
let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u);
|
||||
let ls = aux >> 28u;
|
||||
let db = d * (0.5 + f32(ls)) * 0.5;
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let qs_word = select(qs_hi, qs_lo, l < 2u);
|
||||
let byte_pos = (l % 2u) * 2u;
|
||||
let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu;
|
||||
let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu;
|
||||
let signs_idx = (aux >> (7u * l)) & 0x7Fu;
|
||||
let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu;
|
||||
let grid1 = iq3xxs_grid[grid_idx_0];
|
||||
let grid2 = iq3xxs_grid[grid_idx_1];
|
||||
for (var j = 0u; j < 4u; j++) {
|
||||
let b1 = f32((grid1 >> (j * 8u)) & 0xFFu);
|
||||
let b2 = f32((grid2 >> (j * 8u)) & 0xFFu);
|
||||
let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u);
|
||||
let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u);
|
||||
row_sum += db * b1 * s1 * x_block[ll * 8u + j];
|
||||
row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ3_S
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 110
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let slot0 = half * 2u;
|
||||
let y_offset = sub_blk * 32u + slot0 * 8u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u);
|
||||
let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u);
|
||||
let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u);
|
||||
let qh_byte = get_byte(qh_word, sub_blk % 4u);
|
||||
let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u);
|
||||
let sc_word = load_u32_at_src0(block_byte_base + 106u);
|
||||
let scales_byte = get_byte(sc_word, sub_blk / 2u);
|
||||
let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu;
|
||||
let db = d * (1.0 + 2.0 * f32(sub_scale));
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var ll = 0u; ll < 2u; ll++) {
|
||||
let l = slot0 + ll;
|
||||
let qs_word = select(qs_hi, qs_lo, l < 2u);
|
||||
let byte_pos = (l % 2u) * 2u;
|
||||
let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu;
|
||||
let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu;
|
||||
let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u);
|
||||
let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u);
|
||||
let sign_byte = get_byte(sg_w, l);
|
||||
let grid1 = iq3s_grid[grid_idx_1];
|
||||
let grid2 = iq3s_grid[grid_idx_2];
|
||||
for (var j = 0u; j < 4u; j++) {
|
||||
let b1 = f32((grid1 >> (j * 8u)) & 0xFFu);
|
||||
let b2 = f32((grid2 >> (j * 8u)) & 0xFFu);
|
||||
let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u);
|
||||
let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u);
|
||||
row_sum += db * b1 * s1 * x_block[ll * 8u + j];
|
||||
row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u];
|
||||
}
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ4_NL
|
||||
#define BLOCK_SIZE 32
|
||||
#define BLOCK_SIZE_BYTES 18
|
||||
#define THREADS_PER_BLOCK 4
|
||||
#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK)
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
let thread_within_block = thread_id % THREADS_PER_BLOCK;
|
||||
for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u;
|
||||
var x_block: array<f32, ELEMS_PER_THREAD>;
|
||||
for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
x_block[i + 4u] = f32(src1[x_base + i + 16u]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
var row_sum = 0.0;
|
||||
|
||||
let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block);
|
||||
for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) {
|
||||
let q_byte = get_byte(q_packed, byte_idx);
|
||||
let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d;
|
||||
let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d;
|
||||
row_sum += q_lo * x_block[byte_idx];
|
||||
row_sum += q_hi * x_block[byte_idx + 4u];
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_IQ4_XS
|
||||
#define BLOCK_SIZE 256
|
||||
#define BLOCK_SIZE_BYTES 136
|
||||
#define THREADS_PER_BLOCK 16
|
||||
|
||||
let tid = thread_id % THREADS_PER_BLOCK;
|
||||
let block_group = thread_id / THREADS_PER_BLOCK;
|
||||
let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK;
|
||||
|
||||
let sub_blk = tid / 2u;
|
||||
let half = tid % 2u;
|
||||
let y_offset = sub_blk * 32u + half * 16u;
|
||||
|
||||
let num_blocks = params.k / BLOCK_SIZE;
|
||||
|
||||
for (var block = block_group; block < num_blocks; block += num_block_groups) {
|
||||
let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset;
|
||||
var x_block: array<f32, 16>;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
x_block[i] = f32(src1[x_base + i]);
|
||||
}
|
||||
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let output_row = row_base + row;
|
||||
if (output_row < params.m) {
|
||||
let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES;
|
||||
let d = f32(load_f16_at_src0(block_byte_base));
|
||||
let scales_h = load_u16_at_src0(block_byte_base + 2u);
|
||||
let scales_l_word = load_u32_at_src0(block_byte_base + 4u);
|
||||
let sl_byte = get_byte(scales_l_word, sub_blk / 2u);
|
||||
let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu;
|
||||
let sh_bits = (scales_h >> (2u * sub_blk)) & 3u;
|
||||
let ls = i32(sl | (sh_bits << 4u));
|
||||
let dl = d * f32(ls - 32);
|
||||
|
||||
let qs_byte_off = 8u + sub_blk * 16u;
|
||||
let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off);
|
||||
let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u);
|
||||
let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u);
|
||||
let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u);
|
||||
|
||||
var row_sum = 0.0;
|
||||
for (var i = 0u; i < 16u; i++) {
|
||||
let q_word = select(
|
||||
select(q_w0, q_w1, i >= 4u),
|
||||
select(q_w2, q_w3, i >= 12u),
|
||||
i >= 8u);
|
||||
let q_byte = get_byte(q_word, i % 4u);
|
||||
let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u);
|
||||
row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i];
|
||||
}
|
||||
acc[row] += row_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_SUBGROUP_REDUCTION
|
||||
for (var row = 0u; row < OUTPUTS_PER_WG; row++) {
|
||||
let subgroup_total = subgroupAdd(acc[row]);
|
||||
|
||||
@@ -2249,6 +2249,46 @@ static void test_template_output_peg_parsers(bool detailed_debug) {
|
||||
.reasoning_format(COMMON_REASONING_FORMAT_AUTO)
|
||||
.expect(message_assist)
|
||||
.run();
|
||||
|
||||
{
|
||||
// additional tests for https://github.com/ggml-org/llama.cpp/pull/21760
|
||||
auto tmpls = read_templates("models/templates/google-gemma-4-31B-it.jinja");
|
||||
|
||||
common_chat_msg tool_call_msg = simple_assist_msg(
|
||||
"Let me check.", "", "special_function", "{\"arg1\": 1}","c0");
|
||||
|
||||
common_chat_msg tool_msg;
|
||||
tool_msg.role = "tool";
|
||||
tool_msg.tool_name = "special_function";
|
||||
tool_msg.tool_call_id = "c0";
|
||||
tool_msg.content = "{\"r\":\"ok\"}";
|
||||
|
||||
{
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = { message_user, tool_call_msg, tool_msg };
|
||||
inputs.tools = { special_function_tool };
|
||||
inputs.add_generation_prompt = true;
|
||||
|
||||
auto params = common_chat_templates_apply(tmpls.get(), inputs);
|
||||
|
||||
if (!string_ends_with(params.prompt, "<turn|>\n<|turn>model\n")) {
|
||||
throw std::runtime_error("Missing generation prompt for Gemma 4");
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = { message_user, tool_call_msg, tool_msg };
|
||||
inputs.tools = { special_function_tool };
|
||||
inputs.add_generation_prompt = false;
|
||||
|
||||
auto params = common_chat_templates_apply(tmpls.get(), inputs);
|
||||
|
||||
if (string_ends_with(params.prompt, "<|turn>model\n")) {
|
||||
throw std::runtime_error("Gemma 4: generation prompt was modified despite add_generation_prompt=false");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
@@ -317,7 +317,7 @@ int main(int argc, char * argv[]) {
|
||||
const char * cache_dir = nullptr;
|
||||
std::string cache_dir_str;
|
||||
if (params.use_cache) {
|
||||
cache_dir_str = fs_get_cache_directory() + "rpc/";
|
||||
cache_dir_str = fs_get_cache_directory() + "rpc" + DIRECTORY_SEPARATOR;
|
||||
if (!fs_create_directory_with_parents(cache_dir_str)) {
|
||||
fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str());
|
||||
return 1;
|
||||
|
||||
@@ -575,14 +575,14 @@ json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) {
|
||||
json convert_transcriptions_to_chatcmpl(
|
||||
const json & inp_body,
|
||||
const common_chat_templates * tmpls,
|
||||
const std::map<std::string, raw_buffer> & in_files,
|
||||
const std::map<std::string, uploaded_file> & in_files,
|
||||
std::vector<raw_buffer> & out_files) {
|
||||
// TODO @ngxson : this function may need to be improved in the future
|
||||
// handle input files
|
||||
out_files.clear();
|
||||
auto it = in_files.find("file");
|
||||
if (it != in_files.end()) {
|
||||
out_files.push_back(it->second);
|
||||
out_files.push_back(it->second.data);
|
||||
} else {
|
||||
throw std::invalid_argument("No input file found for transcription");
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "chat.h"
|
||||
#include "server-common.h"
|
||||
#include "server-http.h"
|
||||
|
||||
#include <nlohmann/json_fwd.hpp>
|
||||
|
||||
@@ -19,7 +20,7 @@ json server_chat_convert_anthropic_to_oai(const json & body);
|
||||
json convert_transcriptions_to_chatcmpl(
|
||||
const json & body,
|
||||
const common_chat_templates * tmpls,
|
||||
const std::map<std::string, raw_buffer> & in_files,
|
||||
const std::map<std::string, uploaded_file> & in_files,
|
||||
std::vector<raw_buffer> & out_files);
|
||||
|
||||
json server_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|
||||
|
||||
@@ -49,6 +49,7 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin
|
||||
parsed_url.path,
|
||||
headers,
|
||||
req.body,
|
||||
req.files,
|
||||
req.should_stop,
|
||||
600, // timeout_read (default to 10 minutes)
|
||||
600 // timeout_write (default to 10 minutes)
|
||||
|
||||
@@ -438,7 +438,7 @@ void server_http_context::get(const std::string & path, const server_http_contex
|
||||
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
|
||||
pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
|
||||
std::string body = req.body;
|
||||
std::map<std::string, raw_buffer> files;
|
||||
std::map<std::string, uploaded_file> files;
|
||||
|
||||
if (req.is_multipart_form_data()) {
|
||||
// translate text fields to a JSON object and use it as the body
|
||||
@@ -459,7 +459,11 @@ void server_http_context::post(const std::string & path, const server_http_conte
|
||||
|
||||
// populate files from multipart form
|
||||
for (const auto & [key, file] : req.form.files) {
|
||||
files[key] = raw_buffer(file.content.begin(), file.content.end());
|
||||
files[key] = uploaded_file{
|
||||
raw_buffer(file.content.begin(), file.content.end()),
|
||||
file.filename,
|
||||
file.content_type,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,13 +36,19 @@ struct server_http_res {
|
||||
using server_http_res_ptr = std::unique_ptr<server_http_res>;
|
||||
using raw_buffer = std::vector<uint8_t>;
|
||||
|
||||
struct uploaded_file {
|
||||
raw_buffer data;
|
||||
std::string filename;
|
||||
std::string content_type;
|
||||
};
|
||||
|
||||
struct server_http_req {
|
||||
std::map<std::string, std::string> params; // path_params + query_params
|
||||
std::map<std::string, std::string> headers; // used by MCP proxy
|
||||
std::string path;
|
||||
std::string query_string; // query parameters string (e.g. "action=save")
|
||||
std::string body;
|
||||
std::map<std::string, raw_buffer> files; // used for file uploads (form data)
|
||||
std::map<std::string, uploaded_file> files; // used for file uploads (form data)
|
||||
const std::function<bool()> & should_stop;
|
||||
|
||||
std::string get_param(const std::string & key, const std::string & def = "") const {
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
#include <chrono>
|
||||
#include <queue>
|
||||
#include <filesystem>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include <cstring>
|
||||
|
||||
#ifdef _WIN32
|
||||
@@ -823,6 +825,7 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co
|
||||
proxy_path,
|
||||
req.headers,
|
||||
req.body,
|
||||
req.files,
|
||||
req.should_stop,
|
||||
base_params.timeout_read,
|
||||
base_params.timeout_write
|
||||
@@ -1126,6 +1129,77 @@ static bool should_strip_proxy_header(const std::string & header_name) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::string generate_multipart_boundary() {
|
||||
thread_local std::mt19937 gen(std::random_device{}());
|
||||
static const char chars[] = "0123456789abcdefghijklmnopqrstuvwxyz";
|
||||
std::uniform_int_distribution<> dis(0, sizeof(chars) - 2);
|
||||
std::string boundary = "----llama-cpp-proxy-";
|
||||
for (int i = 0; i < 16; i++) {
|
||||
boundary += chars[dis(gen)];
|
||||
}
|
||||
return boundary;
|
||||
}
|
||||
|
||||
static std::string build_multipart_body(
|
||||
const json & form_fields,
|
||||
const std::map<std::string, uploaded_file> & files,
|
||||
const std::string & boundary) {
|
||||
static auto sanitize_field = [](const std::string & text) {
|
||||
std::string result;
|
||||
result.reserve(text.size());
|
||||
for (char c : text) {
|
||||
if (c != '\n' && c != '\r' && c != '"') {
|
||||
result += c;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
std::ostringstream body;
|
||||
|
||||
for (const auto & [key, value] : form_fields.items()) {
|
||||
if (value.is_array()) {
|
||||
for (const auto & item : value) {
|
||||
body << "--" << boundary << "\r\n";
|
||||
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"\r\n";
|
||||
body << "\r\n";
|
||||
if (!item.is_string()) {
|
||||
throw std::invalid_argument("expected string");
|
||||
}
|
||||
body << item.get<std::string>() << "\r\n";
|
||||
}
|
||||
} else {
|
||||
body << "--" << boundary << "\r\n";
|
||||
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"\r\n";
|
||||
body << "\r\n";
|
||||
if (!value.is_string()) {
|
||||
throw std::invalid_argument("expected string");
|
||||
}
|
||||
body << value.get<std::string>() << "\r\n";
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & [key, file] : files) {
|
||||
body << "--" << boundary << "\r\n";
|
||||
body << "Content-Disposition: form-data; name=\"" << sanitize_field(key) << "\"";
|
||||
if (!file.filename.empty()) {
|
||||
body << "; filename=\"" << sanitize_field(file.filename) << "\"";
|
||||
}
|
||||
body << "\r\n";
|
||||
if (!file.content_type.empty()) {
|
||||
body << "Content-Type: " << sanitize_field(file.content_type) << "\r\n";
|
||||
} else {
|
||||
body << "Content-Type: application/octet-stream\r\n";
|
||||
}
|
||||
body << "\r\n";
|
||||
body.write(reinterpret_cast<const char*>(file.data.data()), file.data.size());
|
||||
body << "\r\n";
|
||||
}
|
||||
|
||||
body << "--" << boundary << "--\r\n";
|
||||
return body.str();
|
||||
}
|
||||
|
||||
server_http_proxy::server_http_proxy(
|
||||
const std::string & method,
|
||||
const std::string & scheme,
|
||||
@@ -1134,6 +1208,7 @@ server_http_proxy::server_http_proxy(
|
||||
const std::string & path,
|
||||
const std::map<std::string, std::string> & headers,
|
||||
const std::string & body,
|
||||
const std::map<std::string, uploaded_file> & files,
|
||||
const std::function<bool()> should_stop,
|
||||
int32_t timeout_read,
|
||||
int32_t timeout_write
|
||||
@@ -1195,28 +1270,65 @@ server_http_proxy::server_http_proxy(
|
||||
return pipe->write({{}, 0, std::string(data, data_length), ""});
|
||||
};
|
||||
|
||||
// when files are present, the body was converted from multipart form data to JSON
|
||||
// we need to reconstruct the multipart body for the downstream server
|
||||
std::string effective_body = body;
|
||||
std::string override_content_type;
|
||||
bool has_files = !files.empty();
|
||||
|
||||
if (has_files) {
|
||||
json form_fields = json::parse(body, nullptr, false);
|
||||
if (!form_fields.is_discarded()) {
|
||||
auto boundary = generate_multipart_boundary();
|
||||
effective_body = build_multipart_body(form_fields, files, boundary);
|
||||
override_content_type = "multipart/form-data; boundary=" + boundary;
|
||||
} else {
|
||||
throw std::runtime_error("failed to parse multipart form fields JSON");
|
||||
}
|
||||
}
|
||||
|
||||
// prepare the request to destination server
|
||||
httplib::Request req;
|
||||
{
|
||||
req.method = method;
|
||||
req.path = path;
|
||||
for (const auto & [key, value] : headers) {
|
||||
if (key == "Accept-Encoding") {
|
||||
const auto lowered = to_lower_copy(key);
|
||||
if (lowered == "accept-encoding") {
|
||||
// disable Accept-Encoding to avoid compressed responses
|
||||
continue;
|
||||
}
|
||||
if (key == "Transfer-Encoding") {
|
||||
if (lowered == "transfer-encoding") {
|
||||
// the body is already decoded
|
||||
continue;
|
||||
}
|
||||
if (key == "Host" || key == "host") {
|
||||
if (lowered == "content-length") {
|
||||
// let httplib calculate Content-Length from the actual body
|
||||
continue;
|
||||
}
|
||||
if (lowered == "content-type") {
|
||||
if (has_files) {
|
||||
// we set our own Content-Type with the new boundary
|
||||
continue;
|
||||
}
|
||||
// when no files but the original request was multipart,
|
||||
// the body is now JSON, so correct the Content-Type
|
||||
if (value.find("multipart/form-data") != std::string::npos) {
|
||||
override_content_type = "application/json; charset=utf-8";
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (lowered == "host") {
|
||||
bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80);
|
||||
req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port));
|
||||
} else {
|
||||
req.set_header(key, value);
|
||||
}
|
||||
}
|
||||
req.body = body;
|
||||
req.body = effective_body;
|
||||
if (!override_content_type.empty()) {
|
||||
req.set_header("Content-Type", override_content_type);
|
||||
}
|
||||
req.response_handler = response_handler;
|
||||
req.content_receiver = content_receiver;
|
||||
}
|
||||
|
||||
@@ -202,6 +202,7 @@ public:
|
||||
const std::string & path,
|
||||
const std::map<std::string, std::string> & headers,
|
||||
const std::string & body,
|
||||
const std::map<std::string, uploaded_file> & files,
|
||||
const std::function<bool()> should_stop,
|
||||
int32_t timeout_read,
|
||||
int32_t timeout_write
|
||||
|
||||
Reference in New Issue
Block a user