mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2026-06-24 12:39:45 +02:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1191758c5d | |||
| 00139b660b | |||
| ef9c13d4c2 | |||
| 88636e178f | |||
| ac4105d68b | |||
| be4a6a63eb | |||
| 72a9269172 | |||
| 92e854ab83 | |||
| c5606364b2 | |||
| 0eb874d374 |
@@ -142,7 +142,9 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
||||
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
|
||||
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
|
||||
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
|
||||
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
|
||||
- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2)
|
||||
- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25)
|
||||
- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos)
|
||||
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
|
||||
- [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86)
|
||||
- [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum)
|
||||
|
||||
+3
-9
@@ -603,8 +603,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
|
||||
|
||||
// model is required (except for server)
|
||||
// TODO @ngxson : maybe show a list of available models in CLI in this case
|
||||
bool can_skip_model = params.usage || params.completion || !params.server_base.empty();
|
||||
if (!can_skip_model && params.model.path.empty()) {
|
||||
if (params.model.path.empty()
|
||||
&& !params.usage
|
||||
&& !params.completion) {
|
||||
throw std::invalid_argument("error: --model is required\n");
|
||||
}
|
||||
}
|
||||
@@ -1118,13 +1119,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||
params.completion = true;
|
||||
}
|
||||
));
|
||||
add_opt(common_arg(
|
||||
{"--server-base"}, "URL",
|
||||
string_format("connect to this server instead of starting a new one, example: 'http://localhost:8080' (default: none)"),
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.server_base = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_CLI}));
|
||||
add_opt(common_arg(
|
||||
{"--verbose-prompt"},
|
||||
string_format("print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false"),
|
||||
|
||||
@@ -631,9 +631,6 @@ struct common_params {
|
||||
|
||||
std::map<std::string, std::string> default_template_kwargs;
|
||||
|
||||
// CLI params
|
||||
std::string server_base; // if set, connect to this server instead of starting a new one
|
||||
|
||||
// UI configs
|
||||
bool ui = true;
|
||||
bool ui_mcp_proxy = false;
|
||||
|
||||
@@ -2,16 +2,6 @@
|
||||
|
||||
#include <cpp-httplib/httplib.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <winsock2.h>
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
struct common_http_url {
|
||||
std::string scheme;
|
||||
std::string user;
|
||||
@@ -107,63 +97,3 @@ static std::pair<httplib::Client, common_http_url> common_http_client(const std:
|
||||
static std::string common_http_show_masked_url(const common_http_url & parts) {
|
||||
return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path;
|
||||
}
|
||||
|
||||
static int common_http_get_free_port() {
|
||||
#ifdef _WIN32
|
||||
WSADATA wsaData;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
|
||||
return -1;
|
||||
}
|
||||
typedef SOCKET native_socket_t;
|
||||
#define INVALID_SOCKET_VAL INVALID_SOCKET
|
||||
#define CLOSE_SOCKET(s) closesocket(s)
|
||||
#else
|
||||
typedef int native_socket_t;
|
||||
#define INVALID_SOCKET_VAL -1
|
||||
#define CLOSE_SOCKET(s) close(s)
|
||||
#endif
|
||||
|
||||
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock == INVALID_SOCKET_VAL) {
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct sockaddr_in serv_addr;
|
||||
std::memset(&serv_addr, 0, sizeof(serv_addr));
|
||||
serv_addr.sin_family = AF_INET;
|
||||
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
serv_addr.sin_port = htons(0);
|
||||
|
||||
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
int namelen = sizeof(serv_addr);
|
||||
#else
|
||||
socklen_t namelen = sizeof(serv_addr);
|
||||
#endif
|
||||
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
int port = ntohs(serv_addr.sin_port);
|
||||
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
|
||||
return port;
|
||||
}
|
||||
|
||||
@@ -124,6 +124,7 @@ TEXT_MODEL_MAP: dict[str, str] = {
|
||||
"LLaDAModelLM": "llada",
|
||||
"LLaMAForCausalLM": "llama",
|
||||
"Lfm25AudioTokenizer": "lfm2",
|
||||
"Lfm2BidirectionalModel": "lfm2",
|
||||
"Lfm2ForCausalLM": "lfm2",
|
||||
"Lfm2Model": "lfm2",
|
||||
"Lfm2MoeForCausalLM": "lfm2",
|
||||
|
||||
+10
-3
@@ -64,11 +64,17 @@ class LFM2Model(TextModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Lfm2Model")
|
||||
@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel")
|
||||
class LFM2ColBertModel(LFM2Model):
|
||||
model_arch = gguf.MODEL_ARCH.LFM2
|
||||
dense_tensor_name = "dense_2"
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if self.hf_arch == "Lfm2BidirectionalModel":
|
||||
self.gguf_writer.add_causal_attention(False)
|
||||
self._try_set_pooling_type()
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if not name.startswith(self.dense_tensor_name):
|
||||
name = "model." + name
|
||||
@@ -76,10 +82,11 @@ class LFM2ColBertModel(LFM2Model):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# dense tensor is stored in a separate safetensors file
|
||||
# optional dense tensor is stored in a separate safetensors file
|
||||
from safetensors.torch import load_file
|
||||
tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
|
||||
assert tensors_file.is_file()
|
||||
if not tensors_file.is_file():
|
||||
return
|
||||
tensor = load_file(tensors_file)["linear.weight"]
|
||||
self.gguf_writer.add_embedding_length_out(tensor.shape[0])
|
||||
yield f"{self.dense_tensor_name}.weight", tensor.clone()
|
||||
|
||||
+50
-23
@@ -3688,8 +3688,6 @@ static void ggml_compute_forward_norm_f32(
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -3703,25 +3701,49 @@ static void ggml_compute_forward_norm_f32(
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, x);
|
||||
float mean = sum/ne00;
|
||||
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
|
||||
const float * xf = (const float *) x;
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
float variance = 0;
|
||||
float sum = 0.0;
|
||||
ggml_vec_sum_f32(ne00, &sum, xf);
|
||||
float mean = sum/ne00;
|
||||
|
||||
float * yf = (float *) y;
|
||||
float variance = 0;
|
||||
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
mean = -mean;
|
||||
vDSP_vsadd(x, 1, &mean, y, 1, ne00);
|
||||
vDSP_measqv(y, 1, &variance, ne00);
|
||||
mean = -mean;
|
||||
vDSP_vsadd(xf, 1, &mean, yf, 1, ne00);
|
||||
vDSP_measqv(yf, 1, &variance, ne00);
|
||||
#else
|
||||
variance = ggml_vec_cvar_f32(ne00, y, x, mean);
|
||||
variance = ggml_vec_cvar_f32(ne00, yf, xf, mean);
|
||||
#endif //GGML_USE_ACCELERATE
|
||||
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
ggml_vec_scale_f32(ne00, yf, scale);
|
||||
} else {
|
||||
float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += *(const float *) (x + i00*nb00);
|
||||
}
|
||||
const float mean = sum/ne00;
|
||||
|
||||
float variance = 0.0f;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const float v = *(const float *) (x + i00*nb00) - mean;
|
||||
*(float *) (y + i00*nb0) = v;
|
||||
variance += v * v;
|
||||
}
|
||||
variance /= ne00;
|
||||
|
||||
const float scale = 1.0f/sqrtf(variance + eps);
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
*(float *) (y + i00*nb0) *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4142,8 +4164,6 @@ static void ggml_compute_forward_l2_norm_f32(
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||
|
||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
@@ -4158,20 +4178,27 @@ static void ggml_compute_forward_l2_norm_f32(
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
sum += (ggml_float)(x[i00] * x[i00]);
|
||||
const float xi = *(const float *) (x + i00*nb00);
|
||||
sum += (ggml_float)(xi * xi);
|
||||
}
|
||||
|
||||
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
|
||||
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
|
||||
|
||||
ggml_vec_scale_f32(ne00, y, scale);
|
||||
char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3;
|
||||
|
||||
if (nb00 == sizeof(float) && nb0 == sizeof(float)) {
|
||||
memcpy(y, x, ne00 * sizeof(float));
|
||||
ggml_vec_scale_f32(ne00, (float *) y, scale);
|
||||
} else {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const float xi = *(const float *) (x + i00*nb00);
|
||||
*(float *) (y + i00*nb0) = xi * scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5334,7 +5334,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return true;
|
||||
return ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
break;
|
||||
|
||||
@@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state {
|
||||
}
|
||||
};
|
||||
|
||||
struct vk_conv3d_pipeline_state {
|
||||
vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2,
|
||||
uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned)
|
||||
: s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {}
|
||||
|
||||
uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD;
|
||||
uint32_t aligned;
|
||||
|
||||
bool operator<(const vk_conv3d_pipeline_state &b) const {
|
||||
return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) <
|
||||
std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned);
|
||||
}
|
||||
};
|
||||
|
||||
struct vk_solve_tri_pipeline_state {
|
||||
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
|
||||
: N(N), K(K) {}
|
||||
@@ -777,6 +791,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
|
||||
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_back_f32;
|
||||
vk_pipeline pipeline_acc_f32;
|
||||
vk_pipeline pipeline_set_f32;
|
||||
|
||||
@@ -801,14 +816,10 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64;
|
||||
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
|
||||
vk_pipeline pipeline_scale_f32;
|
||||
vk_pipeline pipeline_sqr_f32;
|
||||
vk_pipeline pipeline_sqrt_f32;
|
||||
vk_pipeline pipeline_sin_f32;
|
||||
vk_pipeline pipeline_cos_f32;
|
||||
vk_pipeline pipeline_log[2];
|
||||
vk_pipeline pipeline_tri[2];
|
||||
vk_pipeline pipeline_diag[2];
|
||||
vk_pipeline pipeline_clamp_f32;
|
||||
vk_pipeline pipeline_clamp[2];
|
||||
vk_pipeline pipeline_pad_f32;
|
||||
vk_pipeline pipeline_roll_f32;
|
||||
vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32;
|
||||
@@ -840,6 +851,10 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_gelu_quick[2];
|
||||
vk_pipeline pipeline_silu[2];
|
||||
vk_pipeline pipeline_relu[2];
|
||||
vk_pipeline pipeline_sqr[2];
|
||||
vk_pipeline pipeline_sqrt[2];
|
||||
vk_pipeline pipeline_sin[2];
|
||||
vk_pipeline pipeline_cos[2];
|
||||
vk_pipeline pipeline_xielu[2];
|
||||
vk_pipeline pipeline_neg[2];
|
||||
vk_pipeline pipeline_tanh[2];
|
||||
@@ -871,7 +886,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_geglu_erf[2];
|
||||
vk_pipeline pipeline_geglu_quick[2];
|
||||
|
||||
vk_pipeline pipeline_leaky_relu_f32;
|
||||
vk_pipeline pipeline_leaky_relu[2];
|
||||
vk_pipeline pipeline_silu_back_f32;
|
||||
vk_pipeline pipeline_diag_mask_inf_f32;
|
||||
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
||||
@@ -924,6 +939,8 @@ struct vk_device_struct {
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f32[CONV_SHAPE_COUNT];
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT];
|
||||
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
|
||||
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
|
||||
|
||||
@@ -1669,6 +1686,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
|
||||
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
|
||||
}
|
||||
|
||||
struct vk_op_conv3d_push_constants {
|
||||
uint32_t OC;
|
||||
uint32_t IC;
|
||||
uint32_t N;
|
||||
|
||||
uint32_t IW;
|
||||
uint32_t IH;
|
||||
uint32_t ID;
|
||||
uint32_t OW;
|
||||
uint32_t OH;
|
||||
uint32_t OD;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
|
||||
uint32_t nb1;
|
||||
uint32_t nb2;
|
||||
uint32_t nb3;
|
||||
|
||||
uint32_t OWmp; uint32_t OWL;
|
||||
uint32_t OWOHmp; uint32_t OWOHL;
|
||||
uint32_t OWOHODmp; uint32_t OWOHODL;
|
||||
};
|
||||
|
||||
template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) {
|
||||
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
|
||||
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
|
||||
init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL);
|
||||
}
|
||||
|
||||
struct vk_op_conv2d_dw_push_constants {
|
||||
uint32_t ne;
|
||||
uint32_t batches;
|
||||
@@ -4074,19 +4126,35 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
#endif
|
||||
|
||||
auto const &ggml_vk_mul_mm_spec = [](std::vector<uint32_t> spec, bool aligned) {
|
||||
spec.push_back(aligned ? 1u : 0u);
|
||||
return spec;
|
||||
};
|
||||
|
||||
const int mul_mat_id_param_count = 5;
|
||||
|
||||
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector<uint32_t> spec, bool aligned, bool mul_mat_id) {
|
||||
if (mul_mat_id && spec.size() > 5) {
|
||||
spec.insert(spec.begin() + 5, aligned ? 1u : 0u);
|
||||
} else {
|
||||
spec.push_back(aligned ? 1u : 0u);
|
||||
}
|
||||
if (mul_mat_id && spec.size() == 6) {
|
||||
spec.push_back(32);
|
||||
}
|
||||
return spec;
|
||||
};
|
||||
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
||||
@@ -4161,17 +4229,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \
|
||||
|
||||
// Create 2 variants, {f16,f32} accumulator
|
||||
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
@@ -4284,32 +4352,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
// bf16 scalar path promotes to f32, no dot2 variant
|
||||
#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l_int[TYPE]) { \
|
||||
@@ -4474,17 +4542,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
||||
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _l[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _m[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
if (device->mul_mat ## ID ## _s[TYPE]) \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
||||
|
||||
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
||||
if (device->mul_mat ## ID ## _l_int[TYPE]) \
|
||||
@@ -4879,6 +4947,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
|
||||
@@ -4903,7 +4972,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
|
||||
@@ -5023,11 +5092,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
@@ -5037,8 +5101,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -5058,6 +5120,12 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
CREATE_UNARY(gelu_quick)
|
||||
CREATE_UNARY(silu)
|
||||
CREATE_UNARY(relu)
|
||||
CREATE_UNARY(sqr)
|
||||
CREATE_UNARY(sqrt)
|
||||
CREATE_UNARY(sin)
|
||||
CREATE_UNARY(cos)
|
||||
CREATE_UNARY(clamp)
|
||||
CREATE_UNARY(leaky_relu)
|
||||
CREATE_UNARY(xielu)
|
||||
CREATE_UNARY(neg)
|
||||
CREATE_UNARY(tanh)
|
||||
@@ -5097,7 +5165,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
CREATE_GLU(geglu_quick)
|
||||
#undef CREATE_GLU
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
|
||||
@@ -5314,7 +5381,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
// conv2d, conv_transpose_2d
|
||||
// conv2d, conv_transpose_2d, conv3d
|
||||
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
|
||||
// smaller WG for the small-tile fallback gives more concurrent WGs per SM
|
||||
uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
|
||||
@@ -5377,8 +5444,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
|
||||
};
|
||||
|
||||
// coopmat1 needs to store the output through shared memory, so check up front
|
||||
// whether it'll fit and disable it before applying coopmat1 parameters.
|
||||
// 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem
|
||||
// layout. cm1 needs Csh for output, so check before applying cm1 params.
|
||||
if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
|
||||
conv2d_use_cm1 = false;
|
||||
}
|
||||
@@ -5470,6 +5537,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
|
||||
}
|
||||
#undef CREATE_CONV
|
||||
#undef CREATE_CONVS
|
||||
|
||||
std::vector<uint32_t> conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD };
|
||||
#define CREATE_CONV3D(type_suffix, spv_suffix) \
|
||||
for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \
|
||||
const vk_conv3d_pipeline_state &state = c.first; \
|
||||
std::vector<uint32_t> spec_constants_cpy = conv3d_spec_constants; \
|
||||
spec_constants_cpy.push_back(state.s0); \
|
||||
spec_constants_cpy.push_back(state.s1); \
|
||||
spec_constants_cpy.push_back(state.s2); \
|
||||
spec_constants_cpy.push_back(state.p0); \
|
||||
spec_constants_cpy.push_back(state.p1); \
|
||||
spec_constants_cpy.push_back(state.p2); \
|
||||
spec_constants_cpy.push_back(state.d0); \
|
||||
spec_constants_cpy.push_back(state.d1); \
|
||||
spec_constants_cpy.push_back(state.d2); \
|
||||
spec_constants_cpy.push_back(state.KW); \
|
||||
spec_constants_cpy.push_back(state.KH); \
|
||||
spec_constants_cpy.push_back(state.KD); \
|
||||
spec_constants_cpy.push_back(state.aligned); \
|
||||
spec_constants_cpy.push_back(conv2d_csh_store); \
|
||||
spec_constants_cpy.push_back(conv2d_WM); \
|
||||
spec_constants_cpy.push_back(conv2d_WN); \
|
||||
ggml_vk_create_pipeline( \
|
||||
device, c.second, "conv3d" #type_suffix, \
|
||||
conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \
|
||||
sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \
|
||||
}
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (device->coopmat2) {
|
||||
CREATE_CONV3D(_f32, _cm2)
|
||||
CREATE_CONV3D(_f16_f32, _cm2)
|
||||
} else
|
||||
#endif
|
||||
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (conv2d_use_cm1) {
|
||||
CREATE_CONV3D(_f32, _cm1)
|
||||
CREATE_CONV3D(_f16_f32, _cm1)
|
||||
} else
|
||||
#endif
|
||||
if (conv2d_UNROLL) {
|
||||
CREATE_CONV3D(_f32, _unroll)
|
||||
CREATE_CONV3D(_f16_f32, _unroll)
|
||||
} else {
|
||||
CREATE_CONV3D(_f32, )
|
||||
CREATE_CONV3D(_f16_f32, )
|
||||
}
|
||||
#undef CREATE_CONV3D
|
||||
}
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
||||
@@ -10294,6 +10408,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_get_rows_f32[src0->type];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_get_rows_back_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ACC:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_acc_f32;
|
||||
@@ -10400,23 +10519,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SQR:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sqr_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SQRT:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sqrt_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_SIN:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_sin_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_COS:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_cos_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_LOG:
|
||||
@@ -10438,8 +10561,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CLAMP:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_clamp_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_PAD:
|
||||
@@ -10807,8 +10931,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_leaky_relu_f32;
|
||||
if (src0->type == dst->type &&
|
||||
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
|
||||
return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CONV_2D:
|
||||
@@ -10885,6 +11010,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_CONV_3D:
|
||||
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11);
|
||||
const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9);
|
||||
const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10);
|
||||
const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0];
|
||||
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ);
|
||||
|
||||
const uint32_t KW = (uint32_t)src0->ne[0];
|
||||
const uint32_t KH = (uint32_t)src0->ne[1];
|
||||
const uint32_t KD = (uint32_t)src0->ne[2];
|
||||
const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0);
|
||||
const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1);
|
||||
const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2);
|
||||
const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3);
|
||||
const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4);
|
||||
const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5);
|
||||
const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6);
|
||||
const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7);
|
||||
const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8);
|
||||
|
||||
const uint32_t CRS = IC * KW * KH * KD;
|
||||
const uint32_t BS_K = vk_conv_block_sizes[shape].K;
|
||||
const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
|
||||
const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
|
||||
const uint32_t aligned = ((OC % BS_K == 0) &&
|
||||
(CRS % BS_CRS == 0) &&
|
||||
(NPQ % BS_NPQ == 0)) ? 1u : 0u;
|
||||
|
||||
vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned);
|
||||
|
||||
std::map<vk_conv3d_pipeline_state, vk_pipeline> *pipelines = nullptr;
|
||||
if (src0->type == GGML_TYPE_F32) {
|
||||
pipelines = &ctx->device->pipeline_conv3d_f32[shape];
|
||||
} else if (src0->type == GGML_TYPE_F16) {
|
||||
pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape];
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
vk_pipeline pipeline = nullptr;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
|
||||
auto it = pipelines->find(conv3d_pipeline_state);
|
||||
if (it != pipelines->end()) {
|
||||
pipeline = it->second;
|
||||
} else {
|
||||
(*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
|
||||
}
|
||||
}
|
||||
|
||||
return pipeline;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ADD1:
|
||||
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
return ctx->device->pipeline_add1_f16_f16;
|
||||
@@ -11135,6 +11315,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
|
||||
break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 };
|
||||
elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
||||
break;
|
||||
case GGML_OP_ARGSORT:
|
||||
GGML_ASSERT(0);
|
||||
break;
|
||||
@@ -11220,6 +11404,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
GGML_ABORT("invalid push constant type for CONV_2D");
|
||||
}
|
||||
break;
|
||||
case GGML_OP_CONV_3D:
|
||||
if constexpr (std::is_same_v<PC, vk_op_conv3d_push_constants>) {
|
||||
const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW;
|
||||
const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ);
|
||||
const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
|
||||
|
||||
elements = { pc.OC, NPQ_blocks, 1 };
|
||||
if (elements[1] > 512) {
|
||||
elements[2] = CEIL_DIV(elements[1], 512);
|
||||
elements[1] = 512;
|
||||
}
|
||||
} else {
|
||||
GGML_ABORT("invalid push constant type for CONV_3D");
|
||||
}
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_DIV:
|
||||
@@ -11236,6 +11435,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
case GGML_OP_TRI:
|
||||
case GGML_OP_DIAG:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_PAD:
|
||||
case GGML_OP_ROLL:
|
||||
case GGML_OP_REPEAT:
|
||||
@@ -11380,6 +11580,21 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
|
||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
0,
|
||||
0.0f, 0.0f, 0,
|
||||
});
|
||||
}
|
||||
|
||||
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
@@ -12087,8 +12302,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||
|
||||
static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
float * op_params = (float *)dst->op_params;
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = op_params[0];
|
||||
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
@@ -13118,6 +13335,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx,
|
||||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
|
||||
const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
|
||||
vk_op_conv3d_push_constants p{};
|
||||
p.IC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 9));
|
||||
p.N = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 10));
|
||||
p.OC = static_cast<uint32_t>(ggml_get_op_params_i32(dst, 11));
|
||||
GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC);
|
||||
GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N);
|
||||
GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N);
|
||||
|
||||
p.IW = static_cast<uint32_t>(ne10);
|
||||
p.IH = static_cast<uint32_t>(ne11);
|
||||
p.ID = static_cast<uint32_t>(ne12);
|
||||
p.OW = static_cast<uint32_t>(ne0);
|
||||
p.OH = static_cast<uint32_t>(ne1);
|
||||
p.OD = static_cast<uint32_t>(ne2);
|
||||
|
||||
// the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the
|
||||
// total input element count must fit in a uint32.
|
||||
GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull);
|
||||
|
||||
p.nb01 = static_cast<uint32_t>(nb01 / nb00);
|
||||
p.nb02 = static_cast<uint32_t>(nb02 / nb00);
|
||||
p.nb03 = static_cast<uint32_t>(nb03 / nb00);
|
||||
|
||||
p.nb11 = static_cast<uint32_t>(nb11 / nb10);
|
||||
p.nb12 = static_cast<uint32_t>(nb12 / nb10);
|
||||
p.nb13 = static_cast<uint32_t>(nb13 / nb10);
|
||||
|
||||
p.nb1 = static_cast<uint32_t>(nb1 / nb0);
|
||||
p.nb2 = static_cast<uint32_t>(nb2 / nb0);
|
||||
p.nb3 = static_cast<uint32_t>(nb3 / nb0);
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p));
|
||||
}
|
||||
|
||||
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
vk_op_conv2d_dw_push_constants p{};
|
||||
p.ne = ggml_nelements(dst);
|
||||
@@ -13144,7 +13406,10 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx
|
||||
|
||||
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
|
||||
const float * op_params = (const float *)dst->op_params;
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
|
||||
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
|
||||
p.param1 = op_params[0];
|
||||
|
||||
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p));
|
||||
}
|
||||
|
||||
#ifdef GGML_VULKAN_RUN_TESTS
|
||||
@@ -14247,6 +14512,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_GET_ROWS:
|
||||
ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_ADD:
|
||||
if (ctx->num_additional_fused_ops) {
|
||||
@@ -14515,6 +14784,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_CONV_TRANSPOSE_2D:
|
||||
ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CONV_3D:
|
||||
ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node);
|
||||
|
||||
break;
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
|
||||
@@ -16964,6 +17237,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
return false;
|
||||
}
|
||||
}
|
||||
case GGML_OP_GET_ROWS_BACK:
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
switch (op->type) {
|
||||
@@ -17060,12 +17335,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_RMS_NORM:
|
||||
return true;
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return ggml_is_contiguous(op->src[0]);
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return ggml_is_contiguous_rows(op->src[0]) &&
|
||||
op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
@@ -17084,8 +17358,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->type == op->src[0]->type;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||
@@ -17285,6 +17560,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
ggml_is_contiguous(op));
|
||||
}
|
||||
case GGML_OP_CONV_3D:
|
||||
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
op->type == GGML_TYPE_F32 &&
|
||||
ggml_is_contiguous(op->src[0]) &&
|
||||
ggml_is_contiguous(op->src[1]) &&
|
||||
ggml_is_contiguous(op);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -18128,6 +18410,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
const int32_t d0 = tensor->op_params[4];
|
||||
const int32_t d1 = tensor->op_params[5];
|
||||
tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
|
||||
} else if (tensor->op == GGML_OP_CONV_3D) {
|
||||
const int32_t s0 = tensor->op_params[0];
|
||||
const int32_t s1 = tensor->op_params[1];
|
||||
const int32_t s2 = tensor->op_params[2];
|
||||
const int32_t p0 = tensor->op_params[3];
|
||||
const int32_t p1 = tensor->op_params[4];
|
||||
const int32_t p2 = tensor->op_params[5];
|
||||
const int32_t d0 = tensor->op_params[6];
|
||||
const int32_t d1 = tensor->op_params[7];
|
||||
const int32_t d2 = tensor->op_params[8];
|
||||
const int32_t IC = tensor->op_params[9];
|
||||
const int32_t N = tensor->op_params[10];
|
||||
const int32_t OC = tensor->op_params[11];
|
||||
tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC);
|
||||
} else if (tensor->op == GGML_OP_CONV_2D_DW) {
|
||||
const int32_t s0 = tensor->op_params[0];
|
||||
const int32_t s1 = tensor->op_params[1];
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
|
||||
}
|
||||
@@ -0,0 +1,431 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#ifdef COOPMAT2
|
||||
#extension GL_NV_cooperative_matrix2 : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#endif
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
|
||||
layout(binding = 0) readonly buffer A {
|
||||
A_TYPE knl_data[];
|
||||
}; // src0 - kernel: [KW, KH, KD, IC*OC]
|
||||
|
||||
layout(binding = 1) readonly buffer B {
|
||||
B_TYPE src_data[];
|
||||
}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format
|
||||
|
||||
layout(binding = 2) writeonly buffer D {
|
||||
D_TYPE dst_data[];
|
||||
}; // dst - result: [OW, OH, OD, OC*N]
|
||||
|
||||
layout(push_constant) uniform parameter {
|
||||
// I/O channels, batch size
|
||||
uint32_t OC;
|
||||
uint32_t IC;
|
||||
uint32_t N;
|
||||
|
||||
// Tensor spatial sizes: input, output
|
||||
uint32_t IW;
|
||||
uint32_t IH;
|
||||
uint32_t ID;
|
||||
uint32_t OW;
|
||||
uint32_t OH;
|
||||
uint32_t OD;
|
||||
|
||||
// Strides in elements
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
|
||||
uint32_t nb1;
|
||||
uint32_t nb2;
|
||||
uint32_t nb3;
|
||||
|
||||
// fastdiv helper values
|
||||
uint32_t OWmp; uint32_t OWL;
|
||||
uint32_t OWOHmp; uint32_t OWOHL;
|
||||
uint32_t OWOHODmp; uint32_t OWOHODL;
|
||||
}
|
||||
|
||||
p;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
// Blocktile sizes
|
||||
layout(constant_id = 1) const uint BS_K = 128;
|
||||
layout(constant_id = 2) const uint BS_CRS = 16;
|
||||
layout(constant_id = 3) const uint BS_NPQ = 128;
|
||||
// Thread-tile sizes
|
||||
layout(constant_id = 4) const uint TS_K = 8;
|
||||
layout(constant_id = 5) const uint SHMEM_PAD = 4;
|
||||
// Stride, padding, dilation
|
||||
layout(constant_id = 6) const uint s0 = 1;
|
||||
layout(constant_id = 7) const uint s1 = 1;
|
||||
layout(constant_id = 8) const uint s2 = 1;
|
||||
layout(constant_id = 9) const uint p0 = 0;
|
||||
layout(constant_id = 10) const uint p1 = 0;
|
||||
layout(constant_id = 11) const uint p2 = 0;
|
||||
layout(constant_id = 12) const uint d0 = 1;
|
||||
layout(constant_id = 13) const uint d1 = 1;
|
||||
layout(constant_id = 14) const uint d2 = 1;
|
||||
// Kernel spatial sizes
|
||||
layout(constant_id = 15) const uint KW = 1;
|
||||
layout(constant_id = 16) const uint KH = 1;
|
||||
layout(constant_id = 17) const uint KD = 1;
|
||||
// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned)
|
||||
layout(constant_id = 18) const uint aligned = 0;
|
||||
// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this.
|
||||
layout(constant_id = 19) const uint csh_store = 0;
|
||||
|
||||
#ifdef COOPMAT
|
||||
// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of
|
||||
// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN ==
|
||||
// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size.
|
||||
layout(constant_id = 20) const uint WM = 32;
|
||||
layout(constant_id = 21) const uint WN = 32;
|
||||
const uint TM = 16;
|
||||
const uint TN = 16;
|
||||
const uint TK = 16;
|
||||
const uint cms_per_row = WM / TM;
|
||||
const uint cms_per_col = WN / TN;
|
||||
const uint warps_M = BS_K / WM;
|
||||
const uint warps_N = BS_NPQ / WN;
|
||||
#endif
|
||||
|
||||
// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction
|
||||
const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0);
|
||||
|
||||
uint32_t tid = gl_LocalInvocationID.x;
|
||||
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
|
||||
|
||||
uint splitWork(uint work_size, uint block_size) {
|
||||
return (block_size + work_size - 1) / block_size;
|
||||
}
|
||||
|
||||
uint32_t K = p.OC;
|
||||
uint32_t CRS = p.IC * KD * KH * KW;
|
||||
uint32_t NPQ = p.N * p.OD * p.OH * p.OW;
|
||||
|
||||
// Number of blocktiles per input
|
||||
uint32_t NB_CRS = splitWork(CRS, BS_CRS);
|
||||
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
#define SHMEM_TYPE float16_t
|
||||
#else
|
||||
#define SHMEM_TYPE float
|
||||
#endif
|
||||
|
||||
const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
|
||||
const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
|
||||
|
||||
const uint32_t Ash_len = BS_K * Ash_stride;
|
||||
const uint32_t Bsh_len = BS_CRS * Bsh_stride;
|
||||
|
||||
shared SHMEM_TYPE Ash[Ash_len]; // K x CRS
|
||||
shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ
|
||||
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
// stage matC through shmem so global stores are row-major (NPQ-contiguous)
|
||||
const uint32_t Csh_stride = BS_NPQ;
|
||||
#ifdef COOPMAT
|
||||
const uint32_t Csh_len = BS_K * Csh_stride;
|
||||
#else
|
||||
const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1;
|
||||
#endif
|
||||
shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ
|
||||
#endif
|
||||
|
||||
// Threadtile sizes
|
||||
const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
|
||||
|
||||
// Number of threadtiles per blocktile
|
||||
const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
|
||||
|
||||
/*
|
||||
Compute
|
||||
KxCRS @ CRSxNPQ = K x NPQ
|
||||
K=OC
|
||||
C=IC
|
||||
D,R,S=KD,KH,KW
|
||||
Z,P,Q=OD,OH,OW
|
||||
*/
|
||||
|
||||
uint32_t B_idx_K = gl_WorkGroupID.x;
|
||||
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;
|
||||
|
||||
uint32_t T_y = tid / NT_NPQ;
|
||||
uint32_t T_x = tid % NT_NPQ;
|
||||
|
||||
uint32_t Ar = tid / BS_CRS;
|
||||
uint32_t Ac = tid % BS_CRS;
|
||||
const uint32_t ArpWg = WG_SIZE / BS_CRS;
|
||||
|
||||
uint32_t Br = tid / BS_NPQ;
|
||||
uint32_t Bc = tid % BS_NPQ;
|
||||
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
|
||||
|
||||
// see init_fastdiv_values in ggml-vulkan.cpp
|
||||
uint fastdiv(uint n, uint mp, uint L) {
|
||||
uint msbs, lsbs;
|
||||
// msbs = mulhi(n, mp)
|
||||
umulExtended(n, mp, msbs, lsbs);
|
||||
return (msbs + n) >> L;
|
||||
}
|
||||
|
||||
void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) {
|
||||
const uint32_t KHKW = KH * KW;
|
||||
const uint32_t KDKHKW = KD * KHKW;
|
||||
ic = crs_idx / KDKHKW;
|
||||
uint32_t rem = crs_idx - ic * KDKHKW;
|
||||
kd = rem / KHKW;
|
||||
rem = rem - kd * KHKW;
|
||||
kh = rem / KW;
|
||||
kw = rem - kh * KW;
|
||||
}
|
||||
|
||||
void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) {
|
||||
const uint32_t OWOH = p.OW * p.OH;
|
||||
n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL);
|
||||
uint32_t rem = npq_idx - n * p.OD * OWOH;
|
||||
od = fastdiv(rem, p.OWOHmp, p.OWOHL);
|
||||
rem = rem - od * OWOH;
|
||||
oh = fastdiv(rem, p.OWmp, p.OWL);
|
||||
ow = rem - oh * p.OW;
|
||||
}
|
||||
|
||||
#ifdef COOPMAT2
|
||||
#define ACC_TYPE float16_t
|
||||
|
||||
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
|
||||
{
|
||||
uint32_t K_idx = B_idx_K * BS_K + r;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
if (B_idx_NPQ * BS_NPQ >= NPQ) {
|
||||
return;
|
||||
}
|
||||
|
||||
#ifdef COOPMAT2
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
|
||||
matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
|
||||
#elif defined(COOPMAT)
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
|
||||
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
|
||||
sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
|
||||
}
|
||||
const uint warp_r = gl_SubgroupID / warps_N;
|
||||
const uint warp_c = gl_SubgroupID % warps_N;
|
||||
#else
|
||||
float regC[TS_K][TS_NPQ];
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regC[T_ly][T_lx] = 0.0;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
/* Advance block in CRS dim */
|
||||
[[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
|
||||
uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac;
|
||||
uint32_t IC_idx_a;
|
||||
uint32_t KD_idx_a;
|
||||
uint32_t KH_idx_a;
|
||||
uint32_t KW_idx_a;
|
||||
split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a);
|
||||
|
||||
/* Load kernel to A_block: (BS_K x BS_CRS)*/
|
||||
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
|
||||
uint32_t B_ly = r_offset + Ar;
|
||||
uint32_t B_lx = Ac;
|
||||
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
|
||||
uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03;
|
||||
if (aligned == 0) {
|
||||
knl_idx = min(knl_idx, K * CRS - 1);
|
||||
}
|
||||
float val = knl_data[knl_idx];
|
||||
if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) {
|
||||
val = 0.0;
|
||||
}
|
||||
Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
|
||||
}
|
||||
/* Load input to B_block: (BS_CRS x BS_NPQ) */
|
||||
UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
|
||||
uint32_t B_ly = r_offset + Br; /* Row index of B block */
|
||||
uint32_t B_lx = Bc;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
|
||||
uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly;
|
||||
uint32_t IC_idx_b;
|
||||
uint32_t KD_idx_b;
|
||||
uint32_t KH_idx_b;
|
||||
uint32_t KW_idx_b;
|
||||
split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b);
|
||||
|
||||
uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2;
|
||||
uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
|
||||
uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
|
||||
|
||||
uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13;
|
||||
// skip clamp when address can't go OOB
|
||||
if (aligned == 0 || !dhw_in_bounds) {
|
||||
src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1);
|
||||
}
|
||||
float val = src_data[src_idx];
|
||||
bool oob = false;
|
||||
if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) {
|
||||
oob = true;
|
||||
}
|
||||
// also catches lower-bound underflow (idx wraps to 0x80000000+)
|
||||
if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) {
|
||||
oob = true;
|
||||
}
|
||||
if (oob) {
|
||||
val = 0.0;
|
||||
}
|
||||
Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
|
||||
}
|
||||
barrier();
|
||||
#ifdef COOPMAT2
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
|
||||
coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
|
||||
|
||||
coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
matC = coopMatMulAdd(matA, matB, matC);
|
||||
#elif defined(COOPMAT)
|
||||
// each subgroup multiplies its grid of fragments per TK-sized CRS chunk
|
||||
[[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row];
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK;
|
||||
coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
|
||||
const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN;
|
||||
coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (T_y * TS_K < K) {
|
||||
UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
|
||||
float regA[TS_K];
|
||||
float regB[TS_NPQ];
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
|
||||
}
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
|
||||
}
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
barrier();
|
||||
}
|
||||
/* Save C* */
|
||||
#if defined(COOPMAT2) || defined(COOPMAT)
|
||||
// stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores
|
||||
#ifdef COOPMAT
|
||||
const bool use_staged_store = true;
|
||||
#else
|
||||
const bool use_staged_store = (csh_store != 0);
|
||||
#endif
|
||||
if (use_staged_store) {
|
||||
#ifdef COOPMAT
|
||||
// cm1: each subgroup stores its fragment grid into its Csh slot
|
||||
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
|
||||
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
|
||||
const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN;
|
||||
coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
}
|
||||
}
|
||||
#else
|
||||
coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor);
|
||||
#endif
|
||||
barrier();
|
||||
|
||||
// cooperative shmem->global: WG threads spread across BS_NPQ (the
|
||||
// contiguous direction of dst), each iter covers store_rows_per_iter K-rows
|
||||
const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ;
|
||||
const uint32_t store_iters = BS_K / store_rows_per_iter;
|
||||
const uint32_t k_thread_offset = tid / BS_NPQ;
|
||||
const uint32_t npq_thread = tid % BS_NPQ;
|
||||
[[unroll]] for (uint32_t i = 0; i < store_iters; i++) {
|
||||
uint32_t k_local = i * store_rows_per_iter + k_thread_offset;
|
||||
uint32_t K_idx = B_idx_K * BS_K + k_local;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#ifdef COOPMAT2
|
||||
else {
|
||||
coopMatPerElementNV(matC, matC, perElemOpStore);
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
if (T_y * TS_K < K) {
|
||||
for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
|
||||
for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
|
||||
uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
|
||||
uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
|
||||
uint32_t N_idx;
|
||||
uint32_t OD_idx;
|
||||
uint32_t OH_idx;
|
||||
uint32_t OW_idx;
|
||||
split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx);
|
||||
uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3;
|
||||
if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) {
|
||||
dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
|
||||
}
|
||||
@@ -463,6 +463,7 @@ void main() {
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(Sf[r][c]));
|
||||
}
|
||||
rowmaxf += FATTN_KQ_MAX_OFFSET;
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// M = max(rowmax, Mold)
|
||||
|
||||
@@ -352,6 +352,7 @@ void main() {
|
||||
}
|
||||
rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
|
||||
}
|
||||
rowmaxf += FATTN_KQ_MAX_OFFSET;
|
||||
float Moldf = Mf[r];
|
||||
|
||||
// Compute max across the row
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_binary_head.glsl"
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint col = gl_GlobalInvocationID.x;
|
||||
|
||||
if (col >= p.ne20) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) {
|
||||
float sum = 0.0f;
|
||||
for (uint i = 0; i < p.ne10; ++i) {
|
||||
if (data_b[get_boffset() + i*p.nb10] == int(row)) {
|
||||
sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00];
|
||||
}
|
||||
}
|
||||
|
||||
data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum;
|
||||
}
|
||||
}
|
||||
@@ -14,16 +14,13 @@ void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint i3 = row / (p.ne11 * p.ne12);
|
||||
const uint i3_offset = i3 * p.ne12 * p.ne11;
|
||||
const uint i2 = (row - i3_offset) / p.ne11;
|
||||
const uint i2_offset = i2 * p.ne11;
|
||||
const uint i1 = row - i3_offset - i2_offset;
|
||||
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
|
||||
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]);
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_base + i0*p.nb00]);
|
||||
sum[tid] += xi * xi;
|
||||
}
|
||||
|
||||
@@ -39,6 +36,6 @@ void main() {
|
||||
const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1));
|
||||
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]));
|
||||
data_d[d_base + i0*p.nb10] = D_TYPE(scale * FLOAT_TYPE(data_a[a_base + i0*p.nb00]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
void main() {
|
||||
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
|
||||
if (i >= p.KX) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float val = float(data_a[i]);
|
||||
data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
|
||||
}
|
||||
@@ -38,17 +38,7 @@
|
||||
#define LOAD_VEC_B 1
|
||||
#endif
|
||||
|
||||
// Load 2 values at once without affecting index calculations through LOAD_VEC
|
||||
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
|
||||
#define LOAD_VEC_BATCH_A 2
|
||||
#else
|
||||
#define LOAD_VEC_BATCH_A 1
|
||||
#endif
|
||||
#if !defined(ALIGNED)
|
||||
#define LOAD_VEC_BATCH_B 2
|
||||
#else
|
||||
#define LOAD_VEC_BATCH_B 1
|
||||
#endif
|
||||
layout (constant_id = 11) const uint ALIGNED = 0;
|
||||
|
||||
#if !defined(TO_FLOAT_TYPE)
|
||||
#define TO_FLOAT_TYPE FLOAT_TYPE
|
||||
@@ -57,6 +47,13 @@
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
#if defined(DATA_A_F32)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];};
|
||||
#elif defined(DATA_A_F16)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];};
|
||||
#elif defined(DATA_A_BF16)
|
||||
layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];};
|
||||
#endif
|
||||
#if defined(A_TYPE_PACKED16)
|
||||
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||
#endif
|
||||
@@ -65,6 +62,7 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32
|
||||
#endif
|
||||
|
||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||
layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];};
|
||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
@@ -194,13 +192,23 @@ void main() {
|
||||
const uint warp_r = warp_i % (BM / WM);
|
||||
const uint warp_c = warp_i / (BM / WM);
|
||||
|
||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
|
||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
|
||||
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
|
||||
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
|
||||
const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1;
|
||||
const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2;
|
||||
#else
|
||||
const uint LOAD_VEC_A_EFF = LOAD_VEC_A;
|
||||
const uint LOAD_VEC_BATCH_A = 1;
|
||||
#endif
|
||||
const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1;
|
||||
const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 1 : 2;
|
||||
|
||||
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;
|
||||
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
|
||||
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A);
|
||||
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
|
||||
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B);
|
||||
|
||||
const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK;
|
||||
const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK;
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
#ifdef MUL_MAT_ID_USE_SUBGROUPS
|
||||
@@ -239,15 +247,15 @@ void main() {
|
||||
|
||||
uint pos_a =
|
||||
#ifdef MUL_MAT_ID
|
||||
expert_idx * (p.batch_stride_a / LOAD_VEC_A) +
|
||||
expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) +
|
||||
#else
|
||||
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) +
|
||||
batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) +
|
||||
#endif
|
||||
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
|
||||
(ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF;
|
||||
#ifdef MUL_MAT_ID
|
||||
uint pos_b = 0;
|
||||
#else
|
||||
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
|
||||
uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF;
|
||||
#endif
|
||||
|
||||
#ifdef COOPMAT
|
||||
@@ -287,8 +295,8 @@ void main() {
|
||||
|
||||
barrier();
|
||||
|
||||
pos_a += BK / LOAD_VEC_A;
|
||||
pos_b += BK / LOAD_VEC_B;
|
||||
pos_a += BK / LOAD_VEC_A_EFF;
|
||||
pos_b += BK / LOAD_VEC_B_EFF;
|
||||
|
||||
#ifdef COOPMAT
|
||||
[[unroll]] for (uint i = 0; i < BK; i += TK) {
|
||||
|
||||
@@ -36,6 +36,7 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit
|
||||
layout (constant_id = 4) const bool enable_smaller_matrices = false;
|
||||
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
|
||||
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
|
||||
layout (constant_id = 5) const uint ALIGNED = 0;
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
@@ -111,7 +112,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
||||
};
|
||||
|
||||
uint _ne1;
|
||||
layout (constant_id = 5) const uint subgroup_size = 32;
|
||||
layout (constant_id = 6) const uint subgroup_size = 32;
|
||||
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
|
||||
|
||||
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
@@ -297,12 +298,12 @@ void main() {
|
||||
|
||||
// Hint to the compiler that values are aligned (want 16B alignment).
|
||||
// Quants are always block-aligned, no alignment needed.
|
||||
#if ALIGNED
|
||||
if (ALIGNED != 0) {
|
||||
#if QUANT_K == 1
|
||||
stride_a &= ~7;
|
||||
#endif
|
||||
stride_b &= ~7;
|
||||
stride_a &= ~7;
|
||||
#endif
|
||||
stride_b &= ~7;
|
||||
}
|
||||
|
||||
// Create layouts for both clamped and unclamped accesses
|
||||
tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
|
||||
|
||||
@@ -1,50 +1,57 @@
|
||||
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
|
||||
#if defined(DATA_A_F32) || defined(DATA_A_F16)
|
||||
#if LOAD_VEC_A == 8
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa[0].xy;
|
||||
buf_a[buf_idx + 1] = aa[0].zw;
|
||||
buf_a[buf_idx + 2] = aa[1].xy;
|
||||
buf_a[buf_idx + 3] = aa[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa[0].xy;
|
||||
buf_a[buf_idx + 1] = aa[0].zw;
|
||||
buf_a[buf_idx + 2] = aa[1].xy;
|
||||
buf_a[buf_idx + 3] = aa[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]);
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx],
|
||||
data_a[idx + 1]);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx],
|
||||
data_a_scalar[idx + 1]);
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_BF16)
|
||||
#if LOAD_VEC_A == 4
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
#else // LOAD_VEC_BATCH_A == 2
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
|
||||
FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx]));
|
||||
buf_a[buf_idx ] = aa.xy;
|
||||
buf_a[buf_idx + 1] = aa.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
const uint idx = pos_a + col * p.stride_a + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]),
|
||||
TO_FLOAT_TYPE(data_a[idx + 1]));
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_a_scalar[idx + 1]));
|
||||
} else if (idx_m < p.M && block + row * 2 < end_k) {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_a[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
#elif defined(DATA_A_Q4_0)
|
||||
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
|
||||
@@ -526,75 +533,85 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
|
||||
#if !defined(MUL_MAT_ID)
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
|
||||
#if LOAD_VEC_B == 8
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_B == 4
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
if (ALIGNED != 0) {
|
||||
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
#else // LOAD_VEC_BATCH_B == 2
|
||||
const uint idx = pos_b + col * p.stride_b + row * 2;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
|
||||
} else if (idx_n < p.N && block + row * 2 < end_k) {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
|
||||
#if LOAD_VEC_B == 8
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
if (ALIGNED != 0) {
|
||||
// Not supported for b_type bf16 because bf16mat2x4 does not exist
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]);
|
||||
buf_b[buf_idx + 0] = bb[0].xy;
|
||||
buf_b[buf_idx + 1] = bb[0].zw;
|
||||
buf_b[buf_idx + 2] = bb[1].xy;
|
||||
buf_b[buf_idx + 3] = bb[1].zw;
|
||||
return;
|
||||
}
|
||||
#elif LOAD_VEC_B == 4
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
if (ALIGNED != 0) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
|
||||
#if defined(DATA_B_BF16)
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx]));
|
||||
#else
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]);
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
buf_b[buf_idx + 0] = bb.xy;
|
||||
buf_b[buf_idx + 1] = bb.zw;
|
||||
#else // LOAD_VEC_BATCH_B == 2
|
||||
const uint row_i = ic * BN + col;
|
||||
const uint buf_idx = col * SHMEM_STRIDE + row;
|
||||
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]),
|
||||
TO_FLOAT_TYPE(data_b[idx + 1]));
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]),
|
||||
TO_FLOAT_TYPE(data_b_scalar[idx + 1]));
|
||||
} else if (row_i < _ne1 && block + row * 2 < end_k) {
|
||||
const u16vec2 row_idx = row_ids[col];
|
||||
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f);
|
||||
} else {
|
||||
buf_b[buf_idx] = FLOAT_TYPEV2(0.0f);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_head.glsl"
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#define BLOCK_SIZE 512
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||
|
||||
shared vec2 sum[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
|
||||
const uint a_base = get_aoffset() + src0_idx(row * p.ne00);
|
||||
const uint d_base = get_doffset() + dst_idx(row * p.ne10);
|
||||
|
||||
sum[tid] = vec2(0.0f, 0.0f);
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
const float xi = float(data_a[row*p.KX + col]);
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
const float xi = float(data_a[a_base + i0*p.nb00]);
|
||||
sum[tid].x += xi;
|
||||
sum[tid].y += xi * xi;
|
||||
}
|
||||
@@ -34,11 +34,11 @@ void main() {
|
||||
barrier();
|
||||
}
|
||||
|
||||
const float mean = sum[0].x / p.KX;
|
||||
const float var = sum[0].y / p.KX - mean * mean;
|
||||
const float mean = sum[0].x / p.ne00;
|
||||
const float var = sum[0].y / p.ne00 - mean * mean;
|
||||
const float inv_std = inversesqrt(var + p.param1);
|
||||
|
||||
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||
data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
|
||||
[[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) {
|
||||
data_d[d_base + i0*p.nb10] = D_TYPE((float(data_a[a_base + i0*p.nb00]) - mean) * inv_std);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val));
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
#version 450
|
||||
|
||||
#include "types.glsl"
|
||||
#include "generic_unary_head.glsl"
|
||||
|
||||
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
void main() {
|
||||
const uint idx = get_idx();
|
||||
|
||||
if (idx >= p.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
|
||||
}
|
||||
@@ -17,6 +17,30 @@ float op_neg(float x) {
|
||||
return -x;
|
||||
}
|
||||
|
||||
float op_sqr(float x) {
|
||||
return x * x;
|
||||
}
|
||||
|
||||
float op_sqrt(float x) {
|
||||
return sqrt(x);
|
||||
}
|
||||
|
||||
float op_sin(float x) {
|
||||
return sin(x);
|
||||
}
|
||||
|
||||
float op_cos(float x) {
|
||||
return cos(x);
|
||||
}
|
||||
|
||||
float op_clamp(float x) {
|
||||
return clamp(x, p.param1, p.param2);
|
||||
}
|
||||
|
||||
float op_leaky_relu(float x) {
|
||||
return max(x, 0.0f) + min(x, 0.0f) * p.param1;
|
||||
}
|
||||
|
||||
float op_step(float x) {
|
||||
return x >= 0.0f ? 1.0f : 0.0f;
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <future>
|
||||
#include <queue>
|
||||
#include <condition_variable>
|
||||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <cstdlib>
|
||||
@@ -34,6 +35,9 @@
|
||||
|
||||
std::mutex lock;
|
||||
std::vector<std::pair<std::string, std::string>> shader_fnames;
|
||||
// Set when any shader subprocess fails (non-zero exit / stderr / launch failure) so the
|
||||
// build is stopped instead of silently producing a broken libggml-vulkan. (issue #24393)
|
||||
static std::atomic<bool> compile_failed{false};
|
||||
std::locale c_locale("C");
|
||||
|
||||
std::string GLSLC = "glslc";
|
||||
@@ -78,7 +82,7 @@ enum MatMulIdType {
|
||||
|
||||
namespace {
|
||||
|
||||
void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
int execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
|
||||
#ifdef _WIN32
|
||||
HANDLE stdout_read, stdout_write;
|
||||
HANDLE stderr_read, stderr_write;
|
||||
@@ -127,8 +131,11 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
|
||||
CloseHandle(stdout_read);
|
||||
CloseHandle(stderr_read);
|
||||
WaitForSingleObject(pi.hProcess, INFINITE);
|
||||
DWORD exit_code = 1;
|
||||
GetExitCodeProcess(pi.hProcess, &exit_code);
|
||||
CloseHandle(pi.hProcess);
|
||||
CloseHandle(pi.hThread);
|
||||
return (int)exit_code;
|
||||
#else
|
||||
int stdout_pipe[2];
|
||||
int stderr_pipe[2];
|
||||
@@ -175,7 +182,9 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str,
|
||||
|
||||
close(stdout_pipe[0]);
|
||||
close(stderr_pipe[0]);
|
||||
waitpid(pid, nullptr, 0);
|
||||
int status = 0;
|
||||
waitpid(pid, &status, 0);
|
||||
return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -372,13 +381,14 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
|
||||
execute_command(cmd, stdout_str, stderr_str);
|
||||
if (!stderr_str.empty()) {
|
||||
std::cerr << "cannot compile " << name << "\n\n";
|
||||
int exit_code = execute_command(cmd, stdout_str, stderr_str);
|
||||
if (exit_code != 0 || !stderr_str.empty()) {
|
||||
std::cerr << "cannot compile " << name << " (exit code " << exit_code << ")\n\n";
|
||||
for (const auto& part : cmd) {
|
||||
std::cerr << part << " ";
|
||||
}
|
||||
std::cerr << "\n\n" << stderr_str << std::endl;
|
||||
compile_failed = true;
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -398,6 +408,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p
|
||||
shader_fnames.push_back(std::make_pair(name, out_path));
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
|
||||
compile_failed = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -539,11 +550,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
};
|
||||
|
||||
// Shaders with f16 B_TYPE
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
|
||||
// bf16
|
||||
{
|
||||
@@ -565,8 +574,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
#endif
|
||||
{
|
||||
if (!dot2) {
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -583,8 +591,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
}
|
||||
|
||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||
// For unaligned, load one at a time for f32/f16, or two at a time for quants
|
||||
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
|
||||
// For aligned matmul loads
|
||||
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
|
||||
|
||||
@@ -597,13 +603,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
|
||||
|
||||
// don't generate f32 variants for coopmat2
|
||||
if (!coopmat2) {
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
if (tname != "f16" && tname != "f32") {
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
|
||||
}
|
||||
|
||||
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
||||
@@ -850,21 +854,12 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}});
|
||||
string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}});
|
||||
|
||||
string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
|
||||
|
||||
string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}});
|
||||
@@ -891,6 +886,18 @@ void process_shaders() {
|
||||
string_to_spv("silu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_silu"}});
|
||||
string_to_spv("relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_relu"}});
|
||||
string_to_spv("relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_relu"}});
|
||||
string_to_spv("sqr_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqr"}});
|
||||
string_to_spv("sqr_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqr"}});
|
||||
string_to_spv("sqrt_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqrt"}});
|
||||
string_to_spv("sqrt_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqrt"}});
|
||||
string_to_spv("sin_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sin"}});
|
||||
string_to_spv("sin_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sin"}});
|
||||
string_to_spv("cos_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_cos"}});
|
||||
string_to_spv("cos_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_cos"}});
|
||||
string_to_spv("clamp_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_clamp"}});
|
||||
string_to_spv("clamp_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_clamp"}});
|
||||
string_to_spv("leaky_relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_leaky_relu"}});
|
||||
string_to_spv("leaky_relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_leaky_relu"}});
|
||||
string_to_spv("neg_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_neg"}});
|
||||
string_to_spv("neg_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_neg"}});
|
||||
string_to_spv("tanh_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_tanh"}});
|
||||
@@ -948,7 +955,6 @@ void process_shaders() {
|
||||
string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||
string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
|
||||
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||
@@ -1060,6 +1066,31 @@ void process_shaders() {
|
||||
}
|
||||
}
|
||||
|
||||
for (auto unroll : {false, true}) {
|
||||
for (auto a_f16 : {false, true}) {
|
||||
std::map<std::string, std::string> defines = {
|
||||
{"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
|
||||
{"UNROLL", unroll ? "[[unroll]]" : ""},
|
||||
};
|
||||
std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32";
|
||||
string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines);
|
||||
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
||||
if (unroll) {
|
||||
auto cm2_defines = defines;
|
||||
cm2_defines["COOPMAT2"] = "1";
|
||||
string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true);
|
||||
}
|
||||
#endif
|
||||
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
||||
if (unroll) {
|
||||
auto cm1_defines = defines;
|
||||
cm1_defines["COOPMAT"] = "1";
|
||||
string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
|
||||
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
|
||||
@@ -1251,6 +1282,11 @@ int main(int argc, char** argv) {
|
||||
|
||||
process_shaders();
|
||||
|
||||
if (compile_failed) {
|
||||
std::cerr << "vulkan-shaders-gen: one or more shaders failed to compile" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
write_output_files();
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
|
||||
@@ -4270,7 +4270,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
|
||||
supports_op = (op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32) && ggml_is_contiguous_rows(src0);
|
||||
break;
|
||||
case GGML_OP_ROPE:
|
||||
supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
|
||||
|
||||
+14
-4
@@ -190,7 +190,15 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
|
||||
auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs);
|
||||
auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs);
|
||||
|
||||
bx = ggml_concat(ctx0, conv, bx, 0);
|
||||
// causal prepends the state, non-causal pads symmetrically for a centered window
|
||||
if (hparams.causal_attn) {
|
||||
bx = ggml_concat(ctx0, conv, bx, 0);
|
||||
} else {
|
||||
const int64_t pad = (hparams.n_shortconv_l_cache - 1) / 2;
|
||||
auto * left = ggml_cont(ctx0,
|
||||
ggml_view_3d(ctx0, conv, pad, hparams.n_embd, n_seqs, conv->nb[1], conv->nb[2], (d_conv - pad) * conv->nb[0]));
|
||||
bx = ggml_pad_ext(ctx0, ggml_concat(ctx0, left, bx, 0), 0, pad, 0, 0, 0, 0, 0, 0);
|
||||
}
|
||||
GGML_ASSERT(bx->ne[0] > conv->ne[0]);
|
||||
|
||||
// last d_conv columns is a new conv state
|
||||
@@ -266,10 +274,12 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
cur = build_lora_mm(model.output, cur, model.output_s);
|
||||
cb(cur, "result_output", -1);
|
||||
if (!cparams.embeddings) {
|
||||
cur = build_lora_mm(model.output, cur, model.output_s);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
res->t_logits = cur;
|
||||
res->t_logits = cur;
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
@@ -3298,21 +3298,29 @@ struct test_norm : public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const bool v; // whether a is a non-contiguous view
|
||||
const float eps;
|
||||
const bool noncontig_rows;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, v, eps);
|
||||
return VARS_TO_STR5(type, ne, v, eps, noncontig_rows);
|
||||
}
|
||||
|
||||
test_norm(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {64, 5, 4, 3},
|
||||
bool v = false,
|
||||
float eps = 1e-6f)
|
||||
: type(type), ne(ne), v(v), eps(eps) {}
|
||||
float eps = 1e-6f,
|
||||
bool noncontig_rows = false)
|
||||
: type(type), ne(ne), v(v), eps(eps), noncontig_rows(noncontig_rows) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
const std::array<int64_t, 4> ne_a = noncontig_rows ?
|
||||
std::array<int64_t, 4>{ ne[1], ne[0], ne[2], ne[3] } : ne;
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
if (noncontig_rows) {
|
||||
a = ggml_permute(ctx, a, 1, 0, 2, 3);
|
||||
ggml_set_name(a, "permuted a");
|
||||
}
|
||||
if (v) {
|
||||
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
|
||||
ggml_set_name(a, "view of a");
|
||||
@@ -6193,21 +6201,29 @@ struct test_l2_norm : public test_case {
|
||||
const std::array<int64_t, 4> ne;
|
||||
const float eps;
|
||||
bool v;
|
||||
bool noncontig_rows;
|
||||
|
||||
std::string vars() override {
|
||||
return VARS_TO_STR4(type, ne, eps, v);
|
||||
return VARS_TO_STR5(type, ne, eps, v, noncontig_rows);
|
||||
}
|
||||
|
||||
test_l2_norm(ggml_type type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne = {64, 64, 320, 1},
|
||||
float eps = 1e-12f,
|
||||
bool v = false)
|
||||
: type(type), ne(ne), eps(eps), v(v) {}
|
||||
bool v = false,
|
||||
bool noncontig_rows = false)
|
||||
: type(type), ne(ne), eps(eps), v(v), noncontig_rows(noncontig_rows) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
const std::array<int64_t, 4> ne_a = noncontig_rows ?
|
||||
std::array<int64_t, 4>{ ne[1], ne[0], ne[2], ne[3] } : ne;
|
||||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||
ggml_set_name(a, "a");
|
||||
|
||||
if (noncontig_rows) {
|
||||
a = ggml_permute(ctx, a, 1, 0, 2, 3);
|
||||
ggml_set_name(a, "permuted a");
|
||||
}
|
||||
if (v) {
|
||||
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
|
||||
ggml_set_name(a, "view of a");
|
||||
@@ -8282,9 +8298,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
|
||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps));
|
||||
}
|
||||
test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, false, eps, true));
|
||||
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps));
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false));
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true));
|
||||
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false, true));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9272,6 +9290,34 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
||||
}
|
||||
}
|
||||
|
||||
struct conv3d_perf_case {
|
||||
int N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1, s2, p0, p1, p2, d0, d1, d2;
|
||||
};
|
||||
|
||||
const std::vector<conv3d_perf_case> conv3d_cases = {
|
||||
{1, 320, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 1280, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 320, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 1280, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 320, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
#if 0
|
||||
// too slow on some devices
|
||||
{1, 1280, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 320, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 640, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
#endif
|
||||
};
|
||||
|
||||
for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||
for (const conv3d_perf_case & c : conv3d_cases) {
|
||||
test_cases.emplace_back(new test_conv_3d(
|
||||
c.N, c.IC, c.ID, c.IH, c.IW,
|
||||
c.OC, c.KD, c.KH, c.KW,
|
||||
c.s0, c.s1, c.s2, c.p0, c.p1, c.p2, c.d0, c.d1, c.d2,
|
||||
kernel_type));
|
||||
}
|
||||
}
|
||||
|
||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
|
||||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
|
||||
|
||||
|
||||
@@ -2,13 +2,11 @@
|
||||
|
||||
set(TARGET llama-cli-impl)
|
||||
|
||||
add_library(${TARGET} cli.cpp
|
||||
cli-client.cpp
|
||||
cli-context.cpp)
|
||||
add_library(${TARGET} cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS ON)
|
||||
|
||||
target_include_directories(${TARGET} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ../server)
|
||||
target_link_libraries(${TARGET} PUBLIC llama-server-impl llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_link_libraries(${TARGET} PUBLIC server-context llama-common ${CMAKE_THREAD_LIBS_INIT})
|
||||
|
||||
if(LLAMA_TOOLS_INSTALL)
|
||||
install(TARGETS ${TARGET} LIBRARY)
|
||||
|
||||
@@ -1,164 +0,0 @@
|
||||
#include "cli-client.h"
|
||||
|
||||
#include "http.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
// generation can stall for a long time during prompt processing, so the
|
||||
// read timeout must be generous
|
||||
static constexpr time_t CLI_HTTP_READ_TIMEOUT_SEC = 3600;
|
||||
|
||||
// upper bound for the accumulated response body kept for error reporting
|
||||
static constexpr size_t CLI_HTTP_MAX_ERROR_BODY = 1024 * 1024;
|
||||
|
||||
// returns the path with the base url's path prefix prepended (if any)
|
||||
static std::string join_path(const common_http_url & parts, const std::string & path) {
|
||||
if (parts.path.empty() || parts.path == "/") {
|
||||
return path;
|
||||
}
|
||||
std::string prefix = parts.path;
|
||||
if (prefix.back() == '/') {
|
||||
prefix.pop_back();
|
||||
}
|
||||
return prefix + path;
|
||||
}
|
||||
|
||||
json cli_client::get(const std::string & path) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
auto path_with_model = path + (model.empty() ? "" : ("?model=" + model));
|
||||
auto res = cli.Get(join_path(parts, path_with_model));
|
||||
if (!res) {
|
||||
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
throw std::runtime_error("GET " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
|
||||
}
|
||||
json result = json::parse(res->body, nullptr, false);
|
||||
if (result.is_discarded()) {
|
||||
throw std::runtime_error("GET " + path + " returned invalid JSON");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
json cli_client::post(const std::string & path, const json & body) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
auto body_with_model = body;
|
||||
if (!model.empty()) {
|
||||
body_with_model["model"] = model;
|
||||
}
|
||||
auto res = cli.Post(join_path(parts, path), body_with_model.dump(), "application/json");
|
||||
if (!res) {
|
||||
throw std::runtime_error("failed to connect to " + server_base + ": " + httplib::to_string(res.error()));
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
throw std::runtime_error("POST " + path + " failed with status " + std::to_string(res->status) + ": " + res->body);
|
||||
}
|
||||
json result = json::parse(res->body, nullptr, false);
|
||||
if (result.is_discarded()) {
|
||||
throw std::runtime_error("POST " + path + " returned invalid JSON");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
json cli_client::post_sse(const std::string & path,
|
||||
const json & body,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_read_timeout(CLI_HTTP_READ_TIMEOUT_SEC, 0);
|
||||
|
||||
std::string pending; // buffer for incomplete SSE lines
|
||||
std::string raw_body; // accumulated body, used only for error reporting
|
||||
|
||||
auto receiver = [&](const char * data, size_t len) -> bool {
|
||||
if (should_stop()) {
|
||||
return false; // aborts the request
|
||||
}
|
||||
if (raw_body.size() < CLI_HTTP_MAX_ERROR_BODY) {
|
||||
raw_body.append(data, std::min(len, CLI_HTTP_MAX_ERROR_BODY - raw_body.size()));
|
||||
}
|
||||
pending.append(data, len);
|
||||
size_t pos;
|
||||
while ((pos = pending.find('\n')) != std::string::npos) {
|
||||
std::string line = pending.substr(0, pos);
|
||||
pending.erase(0, pos + 1);
|
||||
if (!line.empty() && line.back() == '\r') {
|
||||
line.pop_back();
|
||||
}
|
||||
if (line.rfind("data: ", 0) != 0) {
|
||||
continue;
|
||||
}
|
||||
std::string payload = line.substr(6);
|
||||
if (payload == "[DONE]") {
|
||||
continue;
|
||||
}
|
||||
json event = json::parse(payload, nullptr, false);
|
||||
if (!event.is_discarded()) {
|
||||
on_data(event);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
httplib::Headers headers = {{"Accept", "text/event-stream"}};
|
||||
auto body_with_model = body;
|
||||
if (!model.empty()) {
|
||||
body_with_model["model"] = model;
|
||||
}
|
||||
auto res = cli.Post(join_path(parts, path), headers, body_with_model.dump(), "application/json", receiver);
|
||||
|
||||
if (!res) {
|
||||
if (res.error() == httplib::Error::Canceled && should_stop()) {
|
||||
return json(); // cancelled by the user
|
||||
}
|
||||
return json {{"error", {{"message", "failed to connect to " + server_base + ": " + httplib::to_string(res.error())}}}};
|
||||
}
|
||||
if (res->status < 200 || res->status >= 300) {
|
||||
json error_body = json::parse(raw_body, nullptr, false);
|
||||
if (!error_body.is_discarded() && error_body.contains("error")) {
|
||||
return error_body;
|
||||
}
|
||||
return json {{"error", {{"message", "request failed with status " + std::to_string(res->status)}}}};
|
||||
}
|
||||
return json();
|
||||
}
|
||||
|
||||
bool cli_client::wait_health(const std::function<bool()> & is_aborted) {
|
||||
int connect_attempts = 0;
|
||||
while (!is_aborted()) {
|
||||
auto [cli, parts] = common_http_client(server_base);
|
||||
cli.set_connection_timeout(1, 0);
|
||||
auto res = cli.Get(join_path(parts, "/health"));
|
||||
if (res) {
|
||||
if (res->status == 200) {
|
||||
return true;
|
||||
}
|
||||
// any other status means the server is up but not ready yet
|
||||
// (e.g. 503 while the model is still loading)
|
||||
} else if (++connect_attempts >= 10) {
|
||||
last_error = "failed to connect to " + server_base + ": " + httplib::to_string(res.error());
|
||||
return false;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(300));
|
||||
}
|
||||
last_error = "aborted while waiting for the server to become ready";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> cli_client::list_models() {
|
||||
json resp = get("/v1/models");
|
||||
if (!resp.contains("data") || !resp.at("data").is_array()) {
|
||||
throw std::runtime_error("invalid response from /v1/models");
|
||||
}
|
||||
std::vector<std::string> models;
|
||||
for (const auto & m : resp.at("data")) {
|
||||
if (m.contains("id") && m.at("id").is_string()) {
|
||||
models.push_back(m.at("id").get<std::string>());
|
||||
}
|
||||
}
|
||||
return models;
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// openai-like client for CLI
|
||||
struct cli_client {
|
||||
std::string server_base; // base url, for example "http://127.0.0.1:8080"
|
||||
std::string last_error; // set when wait_health() fails
|
||||
|
||||
std::string model; // optional, set when the server has multiple models (router mode)
|
||||
|
||||
// simple GET request, returns the response json
|
||||
// throws std::runtime_error on transport error or non-2xx status
|
||||
json get(const std::string & path);
|
||||
|
||||
// simple POST request, returns the response json
|
||||
// throws std::runtime_error on transport error or non-2xx status
|
||||
json post(const std::string & path, const json & body);
|
||||
|
||||
// POST request with an SSE streaming response; on_data is invoked once
|
||||
// per "data:" event; the function returns after the stream is finished:
|
||||
// a null json on graceful exit (incl. cancellation via should_stop),
|
||||
// the error response json otherwise
|
||||
json post_sse(const std::string & path,
|
||||
const json & body,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data);
|
||||
|
||||
// poll /health until the server is ready to accept requests
|
||||
// returns false if is_aborted returned true or the server is unreachable
|
||||
bool wait_health(const std::function<bool()> & is_aborted);
|
||||
|
||||
//
|
||||
// higher-level wrappers
|
||||
//
|
||||
|
||||
json create_chat_completion(const json & request,
|
||||
const std::function<bool()> & should_stop,
|
||||
const std::function<void(const json &)> & on_data) {
|
||||
return post_sse("/v1/chat/completions", request, should_stop, on_data);
|
||||
}
|
||||
|
||||
json get_props() {
|
||||
return get("/props");
|
||||
}
|
||||
|
||||
std::vector<std::string> list_models();
|
||||
};
|
||||
@@ -1,559 +0,0 @@
|
||||
#include "cli-context.h"
|
||||
#include "cli-view.h"
|
||||
|
||||
#include "arg.h"
|
||||
#include "base64.hpp"
|
||||
#include "log.h"
|
||||
#include "console.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <set>
|
||||
|
||||
std::atomic<bool> g_cli_interrupted = false;
|
||||
|
||||
static bool should_stop() {
|
||||
return g_cli_interrupted.load();
|
||||
}
|
||||
|
||||
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
|
||||
|
||||
const char * LLAMA_ASCII_LOGO = R"(
|
||||
▄▄ ▄▄
|
||||
██ ██
|
||||
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
|
||||
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██ ██ ██ ██ ██
|
||||
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
|
||||
██ ██
|
||||
▀▀ ▀▀
|
||||
)";
|
||||
|
||||
// number of values an arg consumes on the command line
|
||||
static int arg_num_values(const common_arg & opt) {
|
||||
if (opt.value_hint_2 != nullptr) {
|
||||
return 2;
|
||||
}
|
||||
if (opt.value_hint != nullptr) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static std::string format_error_message(const json & err) {
|
||||
if (err.contains("error") && err.at("error").is_object()) {
|
||||
const auto & e = err.at("error");
|
||||
if (e.contains("message") && e.at("message").is_string()) {
|
||||
return e.at("message").get<std::string>();
|
||||
}
|
||||
}
|
||||
return err.dump();
|
||||
}
|
||||
|
||||
static std::string media_type_from_ext(const std::string & fname) {
|
||||
std::string ext = std::filesystem::path(fname).extension().string();
|
||||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||
if (ext == ".wav" || ext == ".mp3") {
|
||||
return "audio";
|
||||
}
|
||||
if (ext == ".mp4" || ext == ".avi" || ext == ".mkv" || ext == ".mov" || ext == ".webm") {
|
||||
return "video";
|
||||
}
|
||||
return "image";
|
||||
}
|
||||
|
||||
bool cli_context::init() {
|
||||
view::init(params);
|
||||
|
||||
std::optional<view::spinner> spinner;
|
||||
|
||||
bool use_external_server = !params.server_base.empty();
|
||||
if (use_external_server) {
|
||||
std::string base = params.server_base;
|
||||
while (!base.empty() && base.back() == '/') {
|
||||
base.pop_back();
|
||||
}
|
||||
client.server_base = base;
|
||||
|
||||
spinner.emplace("Connecting to server at " + base);
|
||||
} else {
|
||||
if (params.model.path.empty() && params.model.url.empty() &&
|
||||
params.model.hf_repo.empty() && params.model.docker_repo.empty()) {
|
||||
view::show_error(
|
||||
"no model specified",
|
||||
"use -m <file.gguf> or -hf <user/repo> to run a local model,\n"
|
||||
"or --server-base <url> to connect to a running llama-server"
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
spinner.emplace("\n\nLoading model...");
|
||||
|
||||
server.emplace();
|
||||
if (!server->start(params)) {
|
||||
view::show_error("server start failed");
|
||||
return false;
|
||||
}
|
||||
if (!server->wait_ready(should_stop)) {
|
||||
if (!should_stop()) {
|
||||
view::show_error("the server exited before becoming ready");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
client.server_base = server->address();
|
||||
}
|
||||
|
||||
// for --server-base this is the main availability check; for a spawned
|
||||
// server it is a cheap sanity check on top of the ready signal
|
||||
auto is_aborted = [this]() {
|
||||
return should_stop() || (server && !server->alive());
|
||||
};
|
||||
bool healthy = false;
|
||||
try {
|
||||
healthy = client.wait_health(is_aborted);
|
||||
} catch (const std::exception & e) {
|
||||
client.last_error = e.what();
|
||||
}
|
||||
if (!healthy) {
|
||||
if (!should_stop()) {
|
||||
view::show_error(client.last_error);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (use_external_server) {
|
||||
spinner.reset();
|
||||
if (!list_and_ask_models()) {
|
||||
return false;
|
||||
}
|
||||
// restore the spinner for the next step
|
||||
spinner.emplace("Waiting for server...");
|
||||
}
|
||||
|
||||
fetch_server_props();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void cli_context::fetch_server_props() {
|
||||
try {
|
||||
json props = client.get_props();
|
||||
model_name = props.value("model_alias", "");
|
||||
if (model_name.empty()) {
|
||||
const std::string path = props.value("model_path", "");
|
||||
if (!path.empty()) {
|
||||
model_name = std::filesystem::path(path).filename().string();
|
||||
}
|
||||
}
|
||||
build_info = props.value("build_info", "");
|
||||
if (props.contains("modalities") && props.at("modalities").is_object()) {
|
||||
const auto & modalities = props.at("modalities");
|
||||
has_vision = modalities.value("vision", false);
|
||||
has_audio = modalities.value("audio", false);
|
||||
has_video = modalities.value("video", false);
|
||||
}
|
||||
} catch (const std::exception & e) {
|
||||
// /props can be disabled on remote servers; not fatal
|
||||
LOG_DBG("failed to fetch /props: %s\n", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
bool cli_context::list_and_ask_models() {
|
||||
auto models = client.list_models();
|
||||
|
||||
// only one model: use it without asking
|
||||
if (models.size() == 1) {
|
||||
model_name = models[0];
|
||||
client.model = model_name;
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string message = "\nAvailable models:";
|
||||
if (!models.empty()) {
|
||||
for (size_t i = 0; i < models.size(); ++i) {
|
||||
message += "\n " + std::to_string(i + 1) + ". " + models[i];
|
||||
}
|
||||
}
|
||||
message += "\n";
|
||||
view::show_message(message);
|
||||
std::string selection;
|
||||
while (selection.empty()) {
|
||||
if (should_stop()) {
|
||||
return false;
|
||||
}
|
||||
view::user_turn user_turn;
|
||||
selection = user_turn.read_input(false, "Select model by number: ");
|
||||
if (selection.empty()) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
size_t idx = std::stoul(selection);
|
||||
if (idx > 0 && idx <= models.size()) {
|
||||
model_name = models[idx - 1];
|
||||
client.model = model_name;
|
||||
view::show_message("Selected model: " + model_name);
|
||||
break;
|
||||
}
|
||||
} catch (...) {
|
||||
// ignore
|
||||
}
|
||||
view::show_error("Invalid selection. Please enter a valid number.");
|
||||
selection.clear();
|
||||
continue;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void cli_context::add_system_prompt() {
|
||||
if (!params.system_prompt.empty()) {
|
||||
messages.push_back({
|
||||
{"role", "system"},
|
||||
{"content", params.system_prompt}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void cli_context::push_user_message(const std::string & text) {
|
||||
json content;
|
||||
if (pending_media.empty()) {
|
||||
content = text;
|
||||
} else {
|
||||
// multimodal message: media parts first, then the text
|
||||
content = pending_media;
|
||||
content.push_back({
|
||||
{"type", "text"},
|
||||
{"text", text}
|
||||
});
|
||||
pending_media = json::array();
|
||||
}
|
||||
messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", content}
|
||||
});
|
||||
}
|
||||
|
||||
bool cli_context::stage_media_file(const std::string & fname, const std::string & type) {
|
||||
std::ifstream file(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
return false;
|
||||
}
|
||||
std::string data((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
std::string encoded = base64::encode(data);
|
||||
|
||||
if (type == "audio") {
|
||||
std::string ext = std::filesystem::path(fname).extension().string();
|
||||
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||
pending_media.push_back({
|
||||
{"type", "input_audio"},
|
||||
{"input_audio", {
|
||||
{"data", encoded},
|
||||
{"format", ext == ".mp3" ? "mp3" : "wav"}
|
||||
}}
|
||||
});
|
||||
} else if (type == "video") {
|
||||
pending_media.push_back({
|
||||
{"type", "input_video"},
|
||||
{"input_video", {
|
||||
{"data", encoded}
|
||||
}}
|
||||
});
|
||||
} else {
|
||||
// the server detects the actual image type from the data
|
||||
pending_media.push_back({
|
||||
{"type", "image_url"},
|
||||
{"image_url", {
|
||||
{"url", "data:image/unknown;base64," + encoded}
|
||||
}}
|
||||
});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool cli_context::generate_completion(std::string & assistant_content, cli_timings & timings) {
|
||||
json body = {
|
||||
{"messages", messages},
|
||||
{"stream", true},
|
||||
// in order to get timings even when we cancel mid-way
|
||||
{"timings_per_token", true},
|
||||
};
|
||||
|
||||
bool stream_error = false;
|
||||
|
||||
view::assistant_turn a;
|
||||
|
||||
json err = client.create_chat_completion(body, should_stop, [&](const json & chunk) {
|
||||
if (chunk.contains("error")) {
|
||||
stream_error = true;
|
||||
view::show_error(format_error_message(chunk));
|
||||
return;
|
||||
}
|
||||
if (chunk.contains("timings")) {
|
||||
const auto & t = chunk.at("timings");
|
||||
timings.prompt_per_second = t.value("prompt_per_second", 0.0);
|
||||
timings.predicted_per_second = t.value("predicted_per_second", 0.0);
|
||||
}
|
||||
if (!chunk.contains("choices") || !chunk.at("choices").is_array() || chunk.at("choices").empty()) {
|
||||
return;
|
||||
}
|
||||
const auto & choice = chunk.at("choices").at(0);
|
||||
if (!choice.contains("delta")) {
|
||||
return;
|
||||
}
|
||||
const auto & delta = choice.at("delta");
|
||||
if (delta.contains("reasoning_content") && delta.at("reasoning_content").is_string()) {
|
||||
const std::string text = delta.at("reasoning_content").get<std::string>();
|
||||
if (!text.empty()) {
|
||||
a.push(view::ASSISTANT_DISPLAY_MODE_REASONING, text);
|
||||
}
|
||||
}
|
||||
if (delta.contains("content") && delta.at("content").is_string()) {
|
||||
const std::string text = delta.at("content").get<std::string>();
|
||||
if (!text.empty()) {
|
||||
assistant_content += text;
|
||||
a.push(view::ASSISTANT_DISPLAY_MODE_CONTENT, text);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
g_cli_interrupted.store(false);
|
||||
|
||||
if (!err.is_null()) {
|
||||
view::show_error(format_error_message(err));
|
||||
return false;
|
||||
}
|
||||
return !stream_error;
|
||||
}
|
||||
|
||||
int cli_context::run() {
|
||||
add_system_prompt();
|
||||
|
||||
std::string modalities = "text";
|
||||
if (has_vision) {
|
||||
modalities += ", vision";
|
||||
}
|
||||
if (has_audio) {
|
||||
modalities += ", audio";
|
||||
}
|
||||
if (has_video) {
|
||||
modalities += ", video";
|
||||
}
|
||||
|
||||
std::string banner;
|
||||
banner += "\n";
|
||||
banner += LLAMA_ASCII_LOGO;
|
||||
banner += "\n";
|
||||
banner += "build : " + build_info + "\n";
|
||||
banner += "model : " + model_name + "\n";
|
||||
banner += "modalities : " + modalities + "\n";
|
||||
if (!params.system_prompt.empty()) {
|
||||
banner += "using custom system prompt\n";
|
||||
}
|
||||
banner += "\n";
|
||||
banner += "available commands:\n";
|
||||
banner += " /exit or Ctrl+C stop or exit\n";
|
||||
banner += " /regen regenerate the last response\n";
|
||||
banner += " /clear clear the chat history\n";
|
||||
banner += " /read <file> add a text file\n";
|
||||
banner += " /glob <pattern> add text files using globbing pattern\n";
|
||||
if (has_vision) {
|
||||
banner += " /image <file> add an image file\n";
|
||||
}
|
||||
if (has_audio) {
|
||||
banner += " /audio <file> add an audio file\n";
|
||||
}
|
||||
if (has_video) {
|
||||
banner += " /video <file> add a video file\n";
|
||||
}
|
||||
banner += "\n";
|
||||
|
||||
view::show_message(banner);
|
||||
|
||||
// interactive loop
|
||||
std::string cur_msg;
|
||||
|
||||
auto add_text_file = [&](const std::string & fname) -> bool {
|
||||
std::ifstream file(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
return false;
|
||||
}
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
cur_msg += content;
|
||||
view::show_message(string_format("Loaded text from '%s'", fname.c_str()));
|
||||
return true;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
std::string buffer;
|
||||
{
|
||||
view::user_turn user_turn;
|
||||
|
||||
if (params.prompt.empty()) {
|
||||
buffer = user_turn.read_input(params.multiline_input);
|
||||
} else {
|
||||
// process input prompt from args
|
||||
for (auto & fname : params.image) {
|
||||
if (!stage_media_file(fname, media_type_from_ext(fname))) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
break;
|
||||
}
|
||||
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
|
||||
}
|
||||
buffer = params.prompt;
|
||||
user_turn.echo(buffer);
|
||||
params.prompt.clear(); // only use it once
|
||||
}
|
||||
}
|
||||
|
||||
if (should_stop()) {
|
||||
g_cli_interrupted.store(false);
|
||||
break;
|
||||
}
|
||||
|
||||
// remove trailing newline
|
||||
if (!buffer.empty() && buffer.back() == '\n') {
|
||||
buffer.pop_back();
|
||||
}
|
||||
|
||||
// skip empty messages
|
||||
if (buffer.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool add_user_msg = true;
|
||||
|
||||
// process commands
|
||||
if (string_starts_with(buffer, "/exit")) {
|
||||
break;
|
||||
} else if (string_starts_with(buffer, "/regen")) {
|
||||
if (messages.size() >= 2) {
|
||||
size_t last_idx = messages.size() - 1;
|
||||
messages.erase(last_idx);
|
||||
add_user_msg = false;
|
||||
} else {
|
||||
view::show_error("No message to regenerate.");
|
||||
continue;
|
||||
}
|
||||
} else if (string_starts_with(buffer, "/clear")) {
|
||||
messages.clear();
|
||||
add_system_prompt();
|
||||
|
||||
pending_media = json::array();
|
||||
view::show_message("Chat history cleared.");
|
||||
continue;
|
||||
} else if (
|
||||
(string_starts_with(buffer, "/image ") && has_vision) ||
|
||||
(string_starts_with(buffer, "/audio ") && has_audio) ||
|
||||
(string_starts_with(buffer, "/video ") && has_video)) {
|
||||
std::string type = buffer.substr(1, 5);
|
||||
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
|
||||
std::string fname = string_strip(buffer.substr(7));
|
||||
if (!stage_media_file(fname, type)) {
|
||||
view::show_error(string_format("file does not exist or cannot be opened: '%s'", fname.c_str()));
|
||||
continue;
|
||||
}
|
||||
view::show_message(string_format("Loaded media from '%s'", fname.c_str()));
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/read ")) {
|
||||
std::string fname = string_strip(buffer.substr(6));
|
||||
add_text_file(fname);
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/glob ")) {
|
||||
std::error_code ec;
|
||||
size_t count = 0;
|
||||
auto curdir = std::filesystem::current_path();
|
||||
std::string pattern = string_strip(buffer.substr(6));
|
||||
std::filesystem::path rel_path;
|
||||
|
||||
auto startglob = pattern.find_first_of("![*?");
|
||||
if (startglob != std::string::npos && startglob != 0) {
|
||||
auto endpath = pattern.substr(0, startglob).find_last_of('/');
|
||||
if (endpath != std::string::npos) {
|
||||
std::string rel_pattern = pattern.substr(0, endpath);
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(rel_pattern, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
rel_pattern = home + rel_pattern.substr(1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
rel_path = rel_pattern;
|
||||
pattern.erase(0, endpath + 1);
|
||||
curdir /= rel_path;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
|
||||
std::filesystem::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(pattern, rel)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!add_text_file((rel_path / rel).string())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (++count >= FILE_GLOB_MAX_RESULTS) {
|
||||
view::show_error(string_format("Maximum number of globbed files allowed (%zu) reached.", FILE_GLOB_MAX_RESULTS));
|
||||
break;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
// not a command
|
||||
cur_msg += buffer;
|
||||
}
|
||||
|
||||
// generate response
|
||||
if (add_user_msg) {
|
||||
push_user_message(cur_msg);
|
||||
cur_msg.clear();
|
||||
}
|
||||
cli_timings timings;
|
||||
std::string assistant_content;
|
||||
generate_completion(assistant_content, timings);
|
||||
messages.push_back({
|
||||
{"role", "assistant"},
|
||||
{"content", assistant_content}
|
||||
});
|
||||
|
||||
if (params.show_timings) {
|
||||
view::show_info(string_format(
|
||||
"\n[ Prompt: %.1f t/s | Generation: %.1f t/s ]",
|
||||
timings.prompt_per_second,
|
||||
timings.predicted_per_second
|
||||
));
|
||||
}
|
||||
|
||||
if (params.single_turn) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
view::show_message("\n\nExiting...");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void cli_context::shutdown() {
|
||||
if (server) {
|
||||
server->stop();
|
||||
server.reset();
|
||||
}
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
#include "cli-client.h"
|
||||
#include "cli-server.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
struct cli_timings {
|
||||
double prompt_per_second = 0.0;
|
||||
double predicted_per_second = 0.0;
|
||||
};
|
||||
|
||||
// set by the SIGINT handler; cleared once the interrupt has been handled
|
||||
extern std::atomic<bool> g_cli_interrupted;
|
||||
|
||||
struct cli_context {
|
||||
common_params params;
|
||||
|
||||
cli_client client; // always initialized
|
||||
std::optional<cli_server> server; // only set when no --server-base is given
|
||||
|
||||
json messages = json::array();
|
||||
json pending_media = json::array(); // staged multimodal content parts
|
||||
|
||||
// properties of the connected server
|
||||
// will be populated by fetch_server_props()
|
||||
std::string model_name;
|
||||
std::string build_info;
|
||||
bool has_vision = false;
|
||||
bool has_audio = false;
|
||||
bool has_video = false;
|
||||
|
||||
cli_context(const common_params & params) : params(params) {}
|
||||
~cli_context() {
|
||||
shutdown();
|
||||
}
|
||||
|
||||
// connect to --server-base or spawn a local llama-server child;
|
||||
// argc/argv are needed to forward the server-relevant args to the child
|
||||
bool init();
|
||||
|
||||
// run the interactive chat loop, returns the process exit code
|
||||
int run();
|
||||
|
||||
// stop the local server child (if any)
|
||||
void shutdown();
|
||||
|
||||
private:
|
||||
bool generate_completion(std::string & assistant_content, cli_timings & timings);
|
||||
void fetch_server_props();
|
||||
void add_system_prompt();
|
||||
void push_user_message(const std::string & text);
|
||||
|
||||
// check if server have multiple models (router mode)
|
||||
// if yes, list them then ask; do nothing otherwise
|
||||
bool list_and_ask_models();
|
||||
|
||||
// read a file and stage it as a multimodal content part; type is one of
|
||||
// "image", "audio", "video"; returns false if the file cannot be read
|
||||
bool stage_media_file(const std::string & fname, const std::string & type);
|
||||
};
|
||||
@@ -1,83 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <thread>
|
||||
|
||||
#include "http.h"
|
||||
|
||||
// llama_server will be available as a dynamic library symbol
|
||||
int llama_server(common_params & params, int argc, char ** argv);
|
||||
void llama_server_terminate();
|
||||
|
||||
struct cli_server {
|
||||
std::thread th;
|
||||
int port = -1;
|
||||
std::atomic<bool> is_alive = false;
|
||||
std::atomic<bool> is_stopping = false;
|
||||
|
||||
~cli_server() {
|
||||
stop();
|
||||
}
|
||||
|
||||
void stop() {
|
||||
if (alive() && !is_stopping.exchange(true)) {
|
||||
llama_server_terminate();
|
||||
th.join();
|
||||
}
|
||||
}
|
||||
|
||||
// spawn llama-server in a thread and interact with it via a random port
|
||||
bool start(common_params & params) {
|
||||
port = common_http_get_free_port();
|
||||
if (port <= 0) {
|
||||
fprintf(stderr, "failed to get a free port\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
is_alive.store(true, std::memory_order_release);
|
||||
|
||||
th = std::thread([&]() {
|
||||
common_params server_params = params; // copy
|
||||
server_params.port = port;
|
||||
// argc / argv are only used in router mode, we can skip them for now
|
||||
int res = llama_server(server_params, 0, nullptr);
|
||||
if (res != 0) {
|
||||
fprintf(stderr, "llama_server exited with code %d\n", res);
|
||||
}
|
||||
is_alive.store(false, std::memory_order_release);
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string address() const {
|
||||
return "http://127.0.0.1:" + std::to_string(port);
|
||||
}
|
||||
|
||||
bool wait_ready(std::function<bool()> should_stop) {
|
||||
if (!alive()) {
|
||||
return false;
|
||||
}
|
||||
while (!should_stop()) {
|
||||
auto [cli, parts] = common_http_client(address());
|
||||
cli.set_connection_timeout(1, 0);
|
||||
auto res = cli.Get("/health");
|
||||
if (res) {
|
||||
if (res->status == 200) {
|
||||
return true;
|
||||
}
|
||||
// any other status means the server is up but not ready yet
|
||||
// (e.g. 503 while the model is still loading)
|
||||
}
|
||||
if (!alive()) {
|
||||
// in case server die permanently
|
||||
return false;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(200));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool alive() const {
|
||||
return is_alive.load(std::memory_order_acquire);
|
||||
}
|
||||
};
|
||||
@@ -1,250 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "console.h"
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
#include <string_view>
|
||||
|
||||
// TODO?: Make this reusable, enums, docs
|
||||
static const std::array<std::string_view, 8> cmds = {
|
||||
"/audio ",
|
||||
"/clear",
|
||||
"/exit",
|
||||
"/glob ",
|
||||
"/image ",
|
||||
"/read ",
|
||||
"/regen",
|
||||
"/video ",
|
||||
};
|
||||
|
||||
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
|
||||
std::vector<std::pair<std::string, size_t>> matches;
|
||||
std::string cmd;
|
||||
|
||||
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return string_starts_with(line, prefix);
|
||||
})) {
|
||||
auto it = cmds.begin();
|
||||
|
||||
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
|
||||
return string_starts_with(cmd_line, line);
|
||||
})) != cmds.end()) {
|
||||
matches.emplace_back(*it, it->length());
|
||||
++it;
|
||||
}
|
||||
} else {
|
||||
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return prefix.back() == ' ' && string_starts_with(line, prefix);
|
||||
});
|
||||
|
||||
if (it != cmds.end()) {
|
||||
cmd = *it;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
|
||||
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
|
||||
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
|
||||
auto cur_dir = std::filesystem::current_path();
|
||||
std::string cur_dir_str = cur_dir.string();
|
||||
std::string expanded_prefix = path_prefix;
|
||||
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(path_prefix, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
expanded_prefix = home + path_prefix.substr(1);
|
||||
}
|
||||
}
|
||||
if (string_starts_with(expanded_prefix, '/')) {
|
||||
#else
|
||||
if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
|
||||
#endif
|
||||
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
|
||||
cur_dir_str.clear();
|
||||
} else if (!path_prefix.empty()) {
|
||||
cur_dir /= std::filesystem::path(path_prefix).parent_path();
|
||||
}
|
||||
|
||||
std::error_code ec;
|
||||
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
|
||||
if (ec) {
|
||||
break;
|
||||
}
|
||||
if (!entry.exists(ec)) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string path_full = entry.path().string();
|
||||
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
|
||||
|
||||
if (entry.is_directory(ec)) {
|
||||
path_entry.push_back(std::filesystem::path::preferred_separator);
|
||||
}
|
||||
|
||||
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
|
||||
const std::string updated_line = cmd + path_entry;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (matches.empty()) {
|
||||
const std::string updated_line = cmd + path_prefix;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
// Add the longest common prefix
|
||||
if (!expanded_prefix.empty() && matches.size() > 1) {
|
||||
const std::string_view match0(matches[0].first);
|
||||
const std::string_view match1(matches[1].first);
|
||||
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
|
||||
size_t len = it.first - match0.begin();
|
||||
|
||||
for (size_t i = 2; i < matches.size(); ++i) {
|
||||
const std::string_view matchi(matches[i].first);
|
||||
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
|
||||
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
|
||||
}
|
||||
|
||||
const std::string updated_line = std::string(match0.substr(0, len));
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
|
||||
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
|
||||
});
|
||||
}
|
||||
|
||||
return matches;
|
||||
}
|
||||
|
||||
// note: make this view implementation generic, so that we can move to TUI in the future if we want to
|
||||
namespace view {
|
||||
static void init(const common_params & params) {
|
||||
// TODO: avoid using atexit() here by making `console` a singleton
|
||||
console::init(params.simple_io, params.use_color);
|
||||
atexit([]() { console::cleanup(); });
|
||||
|
||||
console::set_completion_callback(auto_completion_callback);
|
||||
}
|
||||
|
||||
struct spinner {
|
||||
spinner(const std::string & message) {
|
||||
if (!message.empty()) {
|
||||
console::log("%s ", message.c_str());
|
||||
}
|
||||
console::spinner::start();
|
||||
}
|
||||
~spinner() {
|
||||
console::spinner::stop();
|
||||
}
|
||||
};
|
||||
|
||||
struct user_turn {
|
||||
user_turn() {
|
||||
console::set_display(DISPLAY_TYPE_USER_INPUT);
|
||||
}
|
||||
~user_turn() {
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
void echo(const std::string & buffer) {
|
||||
if (buffer.size() > 500) {
|
||||
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
|
||||
} else {
|
||||
console::log("\n> %s\n", buffer.c_str());
|
||||
}
|
||||
}
|
||||
std::string read_input(bool multiline_input, const char * prompt = nullptr) {
|
||||
if (prompt) {
|
||||
console::log("%s", prompt);
|
||||
} else {
|
||||
console::log("\n> ");
|
||||
}
|
||||
std::string buffer;
|
||||
std::string line;
|
||||
bool another_line = true;
|
||||
do {
|
||||
another_line = console::readline(line, multiline_input);
|
||||
buffer += line;
|
||||
} while (another_line);
|
||||
return buffer;
|
||||
}
|
||||
};
|
||||
|
||||
enum assistant_display_mode {
|
||||
ASSISTANT_DISPLAY_MODE_REASONING,
|
||||
ASSISTANT_DISPLAY_MODE_CONTENT,
|
||||
};
|
||||
struct assistant_turn {
|
||||
assistant_display_mode mode = ASSISTANT_DISPLAY_MODE_CONTENT;
|
||||
bool trailing_newline = true;
|
||||
bool is_inside_reasoning = false;
|
||||
assistant_turn() {
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
~assistant_turn() {
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
add_newline_if_needed();
|
||||
}
|
||||
void push(assistant_display_mode m, const std::string & buffer) {
|
||||
if (m != mode) {
|
||||
add_newline_if_needed();
|
||||
switch (m) {
|
||||
case ASSISTANT_DISPLAY_MODE_CONTENT:
|
||||
{
|
||||
if (is_inside_reasoning) {
|
||||
console::log("[End thinking]\n\n");
|
||||
is_inside_reasoning = false;
|
||||
}
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
} break;
|
||||
case ASSISTANT_DISPLAY_MODE_REASONING:
|
||||
{
|
||||
console::set_display(DISPLAY_TYPE_REASONING);
|
||||
is_inside_reasoning = true;
|
||||
console::log("\n[Start thinking]\n\n");
|
||||
} break;
|
||||
}
|
||||
}
|
||||
mode = m;
|
||||
if (buffer.empty()) {
|
||||
return;
|
||||
}
|
||||
trailing_newline = buffer.back() == '\n';
|
||||
console::log("%s", buffer.c_str());
|
||||
console::flush();
|
||||
}
|
||||
void add_newline_if_needed() {
|
||||
if (!trailing_newline) {
|
||||
console::log("\n");
|
||||
console::flush();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static void show_error(const std::string & title, const std::string & message = "") {
|
||||
console::spinner::stop();
|
||||
console::error("Error: %s\n", title.c_str());
|
||||
if (!message.empty()) {
|
||||
console::log("%s\n", message.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
static void show_message(const std::string & message) {
|
||||
console::log("%s\n", message.c_str());
|
||||
}
|
||||
|
||||
static void show_info(const std::string & message) {
|
||||
console::set_display(DISPLAY_TYPE_INFO);
|
||||
console::log("%s\n", message.c_str());
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
}
|
||||
+624
-10
@@ -1,10 +1,20 @@
|
||||
#include "arg.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "arg.h"
|
||||
#include "console.h"
|
||||
#include "fit.h"
|
||||
// #include "log.h"
|
||||
|
||||
#include "cli-context.h"
|
||||
#include "cli-view.h"
|
||||
#include "server-common.h"
|
||||
#include "server-context.h"
|
||||
#include "server-task.h"
|
||||
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <algorithm>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <thread>
|
||||
#include <signal.h>
|
||||
|
||||
#if defined(_WIN32)
|
||||
@@ -15,19 +25,342 @@
|
||||
#include <windows.h>
|
||||
#endif
|
||||
|
||||
const char * LLAMA_ASCII_LOGO = R"(
|
||||
▄▄ ▄▄
|
||||
██ ██
|
||||
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
|
||||
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██ ██ ██ ██ ██
|
||||
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
|
||||
██ ██
|
||||
▀▀ ▀▀
|
||||
)";
|
||||
|
||||
static std::atomic<bool> g_is_interrupted = false;
|
||||
static bool should_stop() {
|
||||
return g_is_interrupted.load();
|
||||
}
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
static void signal_handler(int) {
|
||||
if (g_cli_interrupted.load()) {
|
||||
if (g_is_interrupted.load()) {
|
||||
// second Ctrl+C - exit immediately
|
||||
// make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
|
||||
fprintf(stdout, "\033[0m\n");
|
||||
fflush(stdout);
|
||||
std::exit(130);
|
||||
}
|
||||
g_cli_interrupted.store(true);
|
||||
g_is_interrupted.store(true);
|
||||
}
|
||||
#endif
|
||||
|
||||
struct cli_context {
|
||||
server_context ctx_server;
|
||||
json messages = json::array();
|
||||
std::vector<raw_buffer> input_files;
|
||||
task_params defaults;
|
||||
bool verbose_prompt;
|
||||
|
||||
// thread for showing "loading" animation
|
||||
std::atomic<bool> loading_show;
|
||||
|
||||
cli_context(const common_params & params) {
|
||||
defaults.sampling = params.sampling;
|
||||
defaults.speculative = params.speculative;
|
||||
defaults.n_keep = params.n_keep;
|
||||
defaults.n_predict = params.n_predict;
|
||||
defaults.antiprompt = params.antiprompt;
|
||||
|
||||
defaults.stream = true; // make sure we always use streaming mode
|
||||
defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
|
||||
// defaults.return_progress = true; // TODO: show progress
|
||||
|
||||
verbose_prompt = params.verbose_prompt;
|
||||
}
|
||||
|
||||
std::string generate_completion(result_timings & out_timings) {
|
||||
server_response_reader rd = ctx_server.get_response_reader();
|
||||
auto chat_params = format_chat();
|
||||
{
|
||||
// TODO: reduce some copies here in the future
|
||||
server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
|
||||
task.id = rd.get_new_id();
|
||||
task.index = 0;
|
||||
task.params = defaults; // copy
|
||||
task.cli_prompt = chat_params.prompt; // copy
|
||||
task.cli_files = input_files; // copy
|
||||
task.cli = true;
|
||||
|
||||
// chat template settings
|
||||
task.params.chat_parser_params = common_chat_parser_params(chat_params);
|
||||
task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
if (!chat_params.parser.empty()) {
|
||||
task.params.chat_parser_params.parser.load(chat_params.parser);
|
||||
}
|
||||
|
||||
// Copy the preserved tokens into the sampling params
|
||||
const llama_vocab * vocab = llama_model_get_vocab(
|
||||
llama_get_model(ctx_server.get_llama_context()));
|
||||
for (const auto & token : chat_params.preserved_tokens) {
|
||||
auto ids = common_tokenize(vocab, token, false, true);
|
||||
if (ids.size() == 1) {
|
||||
task.params.sampling.preserved_tokens.insert(ids[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// reasoning budget sampler
|
||||
if (!chat_params.thinking_end_tag.empty()) {
|
||||
task.params.sampling.reasoning_budget_tokens = defaults.sampling.reasoning_budget_tokens;
|
||||
task.params.sampling.generation_prompt = chat_params.generation_prompt;
|
||||
|
||||
if (!chat_params.thinking_start_tag.empty()) {
|
||||
task.params.sampling.reasoning_budget_start =
|
||||
common_tokenize(vocab, chat_params.thinking_start_tag, false, true);
|
||||
}
|
||||
task.params.sampling.reasoning_budget_end =
|
||||
common_tokenize(vocab, chat_params.thinking_end_tag, false, true);
|
||||
task.params.sampling.reasoning_budget_forced =
|
||||
common_tokenize(vocab, defaults.sampling.reasoning_budget_message + chat_params.thinking_end_tag, false, true);
|
||||
}
|
||||
|
||||
rd.post_task({std::move(task)});
|
||||
}
|
||||
|
||||
if (verbose_prompt) {
|
||||
console::set_display(DISPLAY_TYPE_PROMPT);
|
||||
console::log("%s\n\n", chat_params.prompt.c_str());
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
|
||||
// wait for first result
|
||||
console::spinner::start();
|
||||
server_task_result_ptr result = rd.next(should_stop);
|
||||
|
||||
while (true) {
|
||||
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
|
||||
if (res_partial && res_partial->is_begin) {
|
||||
// this is the "send 200 status to client" signal in streaming mode
|
||||
// skip, do not stop the spinner
|
||||
result = rd.next(should_stop);
|
||||
} else {
|
||||
console::spinner::stop();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::string curr_content;
|
||||
bool is_thinking = false;
|
||||
|
||||
while (result) {
|
||||
if (should_stop()) {
|
||||
break;
|
||||
}
|
||||
if (result->is_error()) {
|
||||
json err_data = result->to_json();
|
||||
if (err_data.contains("message")) {
|
||||
console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
|
||||
} else {
|
||||
console::error("Error: %s\n", err_data.dump().c_str());
|
||||
}
|
||||
return curr_content;
|
||||
}
|
||||
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
|
||||
if (res_partial) {
|
||||
out_timings = std::move(res_partial->timings);
|
||||
for (const auto & diff : res_partial->oaicompat_msg_diffs) {
|
||||
if (!diff.content_delta.empty()) {
|
||||
if (is_thinking) {
|
||||
console::log("\n[End thinking]\n\n");
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
is_thinking = false;
|
||||
}
|
||||
curr_content += diff.content_delta;
|
||||
console::log("%s", diff.content_delta.c_str());
|
||||
console::flush();
|
||||
}
|
||||
if (!diff.reasoning_content_delta.empty()) {
|
||||
console::set_display(DISPLAY_TYPE_REASONING);
|
||||
if (!is_thinking) {
|
||||
console::log("[Start thinking]\n");
|
||||
}
|
||||
is_thinking = true;
|
||||
console::log("%s", diff.reasoning_content_delta.c_str());
|
||||
console::flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
|
||||
if (res_final) {
|
||||
out_timings = std::move(res_final->timings);
|
||||
break;
|
||||
}
|
||||
result = rd.next(should_stop);
|
||||
}
|
||||
g_is_interrupted.store(false);
|
||||
// server_response_reader automatically cancels pending tasks upon destruction
|
||||
return curr_content;
|
||||
}
|
||||
|
||||
// TODO: support remote files in the future (http, https, etc)
|
||||
std::string load_input_file(const std::string & fname, bool is_media) {
|
||||
std::ifstream file = fs_open_ifstream(fname, std::ios::binary);
|
||||
if (!file) {
|
||||
return "";
|
||||
}
|
||||
if (is_media) {
|
||||
raw_buffer buf;
|
||||
buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
input_files.push_back(std::move(buf));
|
||||
return get_media_marker();
|
||||
} else {
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
return content;
|
||||
}
|
||||
}
|
||||
|
||||
common_chat_params format_chat() {
|
||||
auto meta = ctx_server.get_meta();
|
||||
auto & chat_params = meta.chat_params;
|
||||
|
||||
auto caps = common_chat_templates_get_caps(chat_params.tmpls.get());
|
||||
|
||||
common_chat_templates_inputs inputs;
|
||||
inputs.messages = common_chat_msgs_parse_oaicompat(messages);
|
||||
inputs.tools = {}; // TODO
|
||||
inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE;
|
||||
inputs.json_schema = ""; // TODO
|
||||
inputs.grammar = ""; // TODO
|
||||
inputs.use_jinja = chat_params.use_jinja;
|
||||
inputs.parallel_tool_calls = caps["supports_parallel_tool_calls"];
|
||||
inputs.add_generation_prompt = true;
|
||||
inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
|
||||
inputs.force_pure_content = chat_params.force_pure_content;
|
||||
inputs.enable_thinking = chat_params.enable_thinking ? common_chat_templates_support_enable_thinking(chat_params.tmpls.get()) : false;
|
||||
|
||||
// Apply chat template to the list of messages
|
||||
return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
|
||||
}
|
||||
};
|
||||
|
||||
// TODO?: Make this reusable, enums, docs
|
||||
static const std::array<std::string_view, 8> cmds = {
|
||||
"/audio ",
|
||||
"/clear",
|
||||
"/exit",
|
||||
"/glob ",
|
||||
"/image ",
|
||||
"/read ",
|
||||
"/regen",
|
||||
"/video ",
|
||||
};
|
||||
|
||||
static std::vector<std::pair<std::string, size_t>> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) {
|
||||
std::vector<std::pair<std::string, size_t>> matches;
|
||||
std::string cmd;
|
||||
|
||||
if (line.length() > 1 && line.front() == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return string_starts_with(line, prefix);
|
||||
})) {
|
||||
auto it = cmds.begin();
|
||||
|
||||
while ((it = std::find_if(it, cmds.end(), [line](std::string_view cmd_line) {
|
||||
return string_starts_with(cmd_line, line);
|
||||
})) != cmds.end()) {
|
||||
matches.emplace_back(*it, it->length());
|
||||
++it;
|
||||
}
|
||||
} else {
|
||||
auto it = std::find_if(cmds.begin(), cmds.end(), [line](std::string_view prefix) {
|
||||
return prefix.back() == ' ' && string_starts_with(line, prefix);
|
||||
});
|
||||
|
||||
if (it != cmds.end()) {
|
||||
cmd = *it;
|
||||
}
|
||||
}
|
||||
|
||||
if (!cmd.empty() && cmd != "/glob " && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) {
|
||||
const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length()));
|
||||
const std::string path_postfix = std::string(line.substr(cursor_byte_pos));
|
||||
auto cur_dir = std::filesystem::current_path();
|
||||
std::string cur_dir_str = cur_dir.string();
|
||||
std::string expanded_prefix = path_prefix;
|
||||
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(path_prefix, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
expanded_prefix = home + path_prefix.substr(1);
|
||||
}
|
||||
}
|
||||
if (string_starts_with(expanded_prefix, '/')) {
|
||||
#else
|
||||
if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) {
|
||||
#endif
|
||||
cur_dir = std::filesystem::path(expanded_prefix).parent_path();
|
||||
cur_dir_str.clear();
|
||||
} else if (!path_prefix.empty()) {
|
||||
cur_dir /= std::filesystem::path(path_prefix).parent_path();
|
||||
}
|
||||
|
||||
std::error_code ec;
|
||||
for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) {
|
||||
if (ec) {
|
||||
break;
|
||||
}
|
||||
if (!entry.exists(ec)) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::string path_full = entry.path().string();
|
||||
std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full;
|
||||
|
||||
if (entry.is_directory(ec)) {
|
||||
path_entry.push_back(std::filesystem::path::preferred_separator);
|
||||
}
|
||||
|
||||
if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) {
|
||||
const std::string updated_line = cmd + path_entry;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if (matches.empty()) {
|
||||
const std::string updated_line = cmd + path_prefix;
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
// Add the longest common prefix
|
||||
if (!expanded_prefix.empty() && matches.size() > 1) {
|
||||
const std::string_view match0(matches[0].first);
|
||||
const std::string_view match1(matches[1].first);
|
||||
auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end());
|
||||
size_t len = it.first - match0.begin();
|
||||
|
||||
for (size_t i = 2; i < matches.size(); ++i) {
|
||||
const std::string_view matchi(matches[i].first);
|
||||
auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end());
|
||||
len = std::min(len, static_cast<size_t>(cmp.first - match0.begin()));
|
||||
}
|
||||
|
||||
const std::string updated_line = std::string(match0.substr(0, len));
|
||||
matches.emplace_back(updated_line + path_postfix, updated_line.length());
|
||||
}
|
||||
|
||||
std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) {
|
||||
return a.first.compare(0, a.second, b.first, 0, b.second) < 0;
|
||||
});
|
||||
}
|
||||
|
||||
return matches;
|
||||
}
|
||||
|
||||
static constexpr size_t FILE_GLOB_MAX_RESULTS = 100;
|
||||
|
||||
// satisfies -Wmissing-declarations
|
||||
int llama_cli(int argc, char ** argv);
|
||||
|
||||
@@ -42,6 +375,25 @@ int llama_cli(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// TODO: maybe support it later?
|
||||
if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
|
||||
console::error("--no-conversation is not supported by llama-cli\n");
|
||||
console::error("please use llama-completion instead\n");
|
||||
}
|
||||
|
||||
// struct that contains llama context and inference
|
||||
cli_context ctx_cli(params);
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// TODO: avoid using atexit() here by making `console` a singleton
|
||||
console::init(params.simple_io, params.use_color);
|
||||
atexit([]() { console::cleanup(); });
|
||||
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
console::set_completion_callback(auto_completion_callback);
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
@@ -56,11 +408,273 @@ int llama_cli(int argc, char ** argv) {
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
cli_context ctx_cli(params);
|
||||
|
||||
if (!ctx_cli.init()) {
|
||||
console::log("\nLoading model... "); // followed by loading animation
|
||||
console::spinner::start();
|
||||
if (!ctx_cli.ctx_server.load_model(params)) {
|
||||
console::spinner::stop();
|
||||
console::error("\nFailed to load the model\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
return ctx_cli.run();
|
||||
ctx_cli.defaults.sampling = params.sampling;
|
||||
|
||||
console::spinner::stop();
|
||||
console::log("\n");
|
||||
|
||||
std::thread inference_thread([&ctx_cli]() {
|
||||
ctx_cli.ctx_server.start_loop();
|
||||
});
|
||||
|
||||
auto inf = ctx_cli.ctx_server.get_meta();
|
||||
std::string modalities = "text";
|
||||
if (inf.has_inp_image) {
|
||||
modalities += ", vision";
|
||||
}
|
||||
if (inf.has_inp_audio) {
|
||||
modalities += ", audio";
|
||||
}
|
||||
|
||||
auto add_system_prompt = [&]() {
|
||||
if (!params.system_prompt.empty()) {
|
||||
ctx_cli.messages.push_back({
|
||||
{"role", "system"},
|
||||
{"content", params.system_prompt}
|
||||
});
|
||||
}
|
||||
};
|
||||
add_system_prompt();
|
||||
|
||||
console::log("\n");
|
||||
console::log("%s\n", LLAMA_ASCII_LOGO);
|
||||
console::log("build : %s\n", inf.build_info.c_str());
|
||||
console::log("model : %s\n", inf.model_name.c_str());
|
||||
console::log("modalities : %s\n", modalities.c_str());
|
||||
if (!params.system_prompt.empty()) {
|
||||
console::log("using custom system prompt\n");
|
||||
}
|
||||
console::log("\n");
|
||||
console::log("available commands:\n");
|
||||
console::log(" /exit or Ctrl+C stop or exit\n");
|
||||
console::log(" /regen regenerate the last response\n");
|
||||
console::log(" /clear clear the chat history\n");
|
||||
console::log(" /read <file> add a text file\n");
|
||||
console::log(" /glob <pattern> add text files using globbing pattern\n");
|
||||
if (inf.has_inp_image) {
|
||||
console::log(" /image <file> add an image file\n");
|
||||
}
|
||||
if (inf.has_inp_audio) {
|
||||
console::log(" /audio <file> add an audio file\n");
|
||||
}
|
||||
if (inf.has_inp_video) {
|
||||
console::log(" /video <file> add a video file\n");
|
||||
}
|
||||
console::log("\n");
|
||||
|
||||
// interactive loop
|
||||
std::string cur_msg;
|
||||
|
||||
auto add_text_file = [&](const std::string & fname) -> bool {
|
||||
std::string marker = ctx_cli.load_input_file(fname, false);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
return false;
|
||||
}
|
||||
if (inf.fim_sep_token != LLAMA_TOKEN_NULL) {
|
||||
cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true);
|
||||
cur_msg += fname;
|
||||
cur_msg.push_back('\n');
|
||||
} else {
|
||||
cur_msg += "--- File: ";
|
||||
cur_msg += fname;
|
||||
cur_msg += " ---\n";
|
||||
}
|
||||
cur_msg += marker;
|
||||
console::log("Loaded text from '%s'\n", fname.c_str());
|
||||
return true;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
std::string buffer;
|
||||
console::set_display(DISPLAY_TYPE_USER_INPUT);
|
||||
if (params.prompt.empty()) {
|
||||
console::log("\n> ");
|
||||
std::string line;
|
||||
bool another_line = true;
|
||||
do {
|
||||
another_line = console::readline(line, params.multiline_input);
|
||||
buffer += line;
|
||||
} while (another_line);
|
||||
} else {
|
||||
// process input prompt from args
|
||||
for (auto & fname : params.image) {
|
||||
std::string marker = ctx_cli.load_input_file(fname, true);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
break;
|
||||
}
|
||||
console::log("Loaded media from '%s'\n", fname.c_str());
|
||||
cur_msg += marker;
|
||||
}
|
||||
buffer = params.prompt;
|
||||
if (buffer.size() > 500) {
|
||||
console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
|
||||
} else {
|
||||
console::log("\n> %s\n", buffer.c_str());
|
||||
}
|
||||
params.prompt.clear(); // only use it once
|
||||
}
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
console::log("\n");
|
||||
|
||||
if (should_stop()) {
|
||||
g_is_interrupted.store(false);
|
||||
break;
|
||||
}
|
||||
|
||||
// remove trailing newline
|
||||
if (!buffer.empty() &&buffer.back() == '\n') {
|
||||
buffer.pop_back();
|
||||
}
|
||||
|
||||
// skip empty messages
|
||||
if (buffer.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool add_user_msg = true;
|
||||
|
||||
// process commands
|
||||
if (string_starts_with(buffer, "/exit")) {
|
||||
break;
|
||||
} else if (string_starts_with(buffer, "/regen")) {
|
||||
if (ctx_cli.messages.size() >= 2) {
|
||||
size_t last_idx = ctx_cli.messages.size() - 1;
|
||||
ctx_cli.messages.erase(last_idx);
|
||||
add_user_msg = false;
|
||||
} else {
|
||||
console::error("No message to regenerate.\n");
|
||||
continue;
|
||||
}
|
||||
} else if (string_starts_with(buffer, "/clear")) {
|
||||
ctx_cli.messages.clear();
|
||||
add_system_prompt();
|
||||
|
||||
ctx_cli.input_files.clear();
|
||||
console::log("Chat history cleared.\n");
|
||||
continue;
|
||||
} else if (
|
||||
(string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
|
||||
(string_starts_with(buffer, "/audio ") && inf.has_inp_audio) ||
|
||||
(string_starts_with(buffer, "/video ") && inf.has_inp_video)) {
|
||||
// just in case (bad copy-paste for example), we strip all trailing/leading spaces
|
||||
std::string fname = string_strip(buffer.substr(7));
|
||||
std::string marker = ctx_cli.load_input_file(fname, true);
|
||||
if (marker.empty()) {
|
||||
console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
|
||||
continue;
|
||||
}
|
||||
cur_msg += marker;
|
||||
console::log("Loaded media from '%s'\n", fname.c_str());
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/read ")) {
|
||||
std::string fname = string_strip(buffer.substr(6));
|
||||
add_text_file(fname);
|
||||
continue;
|
||||
} else if (string_starts_with(buffer, "/glob ")) {
|
||||
std::error_code ec;
|
||||
size_t count = 0;
|
||||
auto curdir = std::filesystem::current_path();
|
||||
std::string pattern = string_strip(buffer.substr(6));
|
||||
std::filesystem::path rel_path;
|
||||
|
||||
auto startglob = pattern.find_first_of("![*?");
|
||||
if (startglob != std::string::npos && startglob != 0) {
|
||||
auto endpath = pattern.substr(0, startglob).find_last_of('/');
|
||||
if (endpath != std::string::npos) {
|
||||
std::string rel_pattern = pattern.substr(0, endpath);
|
||||
#if !defined(_WIN32)
|
||||
if (string_starts_with(rel_pattern, '~')) {
|
||||
const char * home = std::getenv("HOME");
|
||||
if (home && home[0]) {
|
||||
rel_pattern = home + rel_pattern.substr(1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
rel_path = rel_pattern;
|
||||
pattern.erase(0, endpath + 1);
|
||||
curdir /= rel_path;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto & entry : std::filesystem::recursive_directory_iterator(curdir,
|
||||
std::filesystem::directory_options::skip_permission_denied, ec)) {
|
||||
if (!entry.is_regular_file()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string rel = std::filesystem::relative(entry.path(), curdir, ec).string();
|
||||
if (ec) {
|
||||
ec.clear();
|
||||
continue;
|
||||
}
|
||||
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||
|
||||
if (!glob_match(pattern, rel)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!add_text_file((rel_path / rel).string())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (++count >= FILE_GLOB_MAX_RESULTS) {
|
||||
console::error("Maximum number of globbed files allowed (%zu) reached.\n", FILE_GLOB_MAX_RESULTS);
|
||||
break;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} else {
|
||||
// not a command
|
||||
cur_msg += buffer;
|
||||
}
|
||||
|
||||
// generate response
|
||||
if (add_user_msg) {
|
||||
ctx_cli.messages.push_back({
|
||||
{"role", "user"},
|
||||
{"content", cur_msg}
|
||||
});
|
||||
cur_msg.clear();
|
||||
}
|
||||
result_timings timings;
|
||||
std::string assistant_content = ctx_cli.generate_completion(timings);
|
||||
ctx_cli.messages.push_back({
|
||||
{"role", "assistant"},
|
||||
{"content", assistant_content}
|
||||
});
|
||||
console::log("\n");
|
||||
|
||||
if (params.show_timings) {
|
||||
console::set_display(DISPLAY_TYPE_INFO);
|
||||
console::log("\n");
|
||||
console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
}
|
||||
|
||||
if (params.single_turn) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
console::set_display(DISPLAY_TYPE_RESET);
|
||||
|
||||
console::log("\nExiting...\n");
|
||||
ctx_cli.ctx_server.terminate();
|
||||
inference_thread.join();
|
||||
|
||||
// bump the log level to display timings
|
||||
common_log_set_verbosity_thold(LOG_LEVEL_INFO);
|
||||
common_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -89,7 +89,9 @@ struct server_batch {
|
||||
}
|
||||
|
||||
~server_batch() {
|
||||
llama_batch_free(batch);
|
||||
if (batch.token != nullptr) {
|
||||
llama_batch_free(batch);
|
||||
}
|
||||
}
|
||||
|
||||
void init(int32_t n_tokens_alloc) {
|
||||
@@ -1215,6 +1217,10 @@ private:
|
||||
cparams.ctx_other = ctx_tgt;
|
||||
|
||||
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
|
||||
if (ctx_dft == nullptr) {
|
||||
SRV_ERR("%s", "failed to create draft context\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
params_base.speculative.draft.ctx_tgt = ctx_tgt;
|
||||
params_base.speculative.draft.ctx_dft = ctx_dft.get();
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#include "build-info.h"
|
||||
#include "preset.h"
|
||||
#include "download.h"
|
||||
#include "http.h"
|
||||
|
||||
#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
|
||||
#include <sheredom/subprocess.h>
|
||||
@@ -26,7 +25,14 @@
|
||||
#include <sstream>
|
||||
#include <cstring>
|
||||
|
||||
#ifndef _WIN32
|
||||
#ifdef _WIN32
|
||||
#include <winsock2.h>
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <sys/socket.h>
|
||||
#include <netinet/in.h>
|
||||
#include <arpa/inet.h>
|
||||
#include <unistd.h>
|
||||
extern char **environ;
|
||||
#endif
|
||||
|
||||
@@ -698,6 +704,66 @@ std::optional<server_model_meta> server_models::get_meta(const std::string & nam
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
static int get_free_port() {
|
||||
#ifdef _WIN32
|
||||
WSADATA wsaData;
|
||||
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
|
||||
return -1;
|
||||
}
|
||||
typedef SOCKET native_socket_t;
|
||||
#define INVALID_SOCKET_VAL INVALID_SOCKET
|
||||
#define CLOSE_SOCKET(s) closesocket(s)
|
||||
#else
|
||||
typedef int native_socket_t;
|
||||
#define INVALID_SOCKET_VAL -1
|
||||
#define CLOSE_SOCKET(s) close(s)
|
||||
#endif
|
||||
|
||||
native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
|
||||
if (sock == INVALID_SOCKET_VAL) {
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct sockaddr_in serv_addr;
|
||||
std::memset(&serv_addr, 0, sizeof(serv_addr));
|
||||
serv_addr.sin_family = AF_INET;
|
||||
serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
|
||||
serv_addr.sin_port = htons(0);
|
||||
|
||||
if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
int namelen = sizeof(serv_addr);
|
||||
#else
|
||||
socklen_t namelen = sizeof(serv_addr);
|
||||
#endif
|
||||
if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
return -1;
|
||||
}
|
||||
|
||||
int port = ntohs(serv_addr.sin_port);
|
||||
|
||||
CLOSE_SOCKET(sock);
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
#endif
|
||||
|
||||
return port;
|
||||
}
|
||||
|
||||
// helper to convert vector<string> to char **
|
||||
// pointers are only valid as long as the original vector is valid
|
||||
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
|
||||
@@ -801,7 +867,7 @@ void server_models::load(const std::string & name, const load_options & opts) {
|
||||
// prepare new instance info
|
||||
instance_t inst;
|
||||
inst.meta = meta;
|
||||
inst.meta.port = common_http_get_free_port();
|
||||
inst.meta.port = get_free_port();
|
||||
inst.meta.status = SERVER_MODEL_STATUS_LOADING;
|
||||
inst.meta.loaded_info = json{};
|
||||
inst.meta.last_used = ggml_time_ms();
|
||||
|
||||
+16
-35
@@ -35,19 +35,6 @@ static inline void signal_handler(int signal) {
|
||||
shutdown_handler(signal);
|
||||
}
|
||||
|
||||
// satisfies -Wmissing-declarations (used by llama command)
|
||||
int llama_server(int argc, char ** argv);
|
||||
|
||||
// to be used via CLI (argc / argv are used by router mode only)
|
||||
int llama_server(common_params & params, int argc, char ** argv);
|
||||
void llama_server_terminate();
|
||||
void llama_server_terminate() {
|
||||
if (shutdown_handler) {
|
||||
shutdown_handler(0);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// wrapper function that handles exceptions and logs errors
|
||||
// this is to make sure handler_t never throws exceptions; instead, it returns an error response
|
||||
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
|
||||
@@ -84,6 +71,9 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
|
||||
};
|
||||
}
|
||||
|
||||
// satisfies -Wmissing-declarations
|
||||
int llama_server(int argc, char ** argv);
|
||||
|
||||
int llama_server(int argc, char ** argv) {
|
||||
std::setlocale(LC_NUMERIC, "C");
|
||||
|
||||
@@ -99,14 +89,8 @@ int llama_server(int argc, char ** argv) {
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
return llama_server(params, argc, argv);
|
||||
}
|
||||
|
||||
int llama_server(common_params & params, int argc, char ** argv) {
|
||||
bool is_run_by_cli = (argv == nullptr);
|
||||
|
||||
// note: router mode also accepts -hf remote-preset, so we need to check that first
|
||||
if (!is_run_by_cli && !params.model.hf_repo.empty()) {
|
||||
if (!params.model.hf_repo.empty()) {
|
||||
try {
|
||||
common_params_handle_models_params handle_params;
|
||||
handle_params.preset_only = true;
|
||||
@@ -288,9 +272,8 @@ int llama_server(common_params & params, int argc, char ** argv) {
|
||||
|
||||
if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) {
|
||||
return child.run_download(params);
|
||||
} else if (!is_router_server && !is_run_by_cli) {
|
||||
} else if (!is_router_server) {
|
||||
// single-model mode (NOT spawned by router)
|
||||
// if this is invoked by CLI, model downloading should be already handled
|
||||
common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {});
|
||||
}
|
||||
|
||||
@@ -373,22 +356,20 @@ int llama_server(common_params & params, int argc, char ** argv) {
|
||||
};
|
||||
}
|
||||
|
||||
// register signal handler if not running by CLI
|
||||
if (!is_run_by_cli) {
|
||||
// TODO: refactor in common/console
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
sigaction(SIGTERM, &sigint_action, NULL);
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = signal_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
sigaction(SIGTERM, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (is_router_server) {
|
||||
SRV_INF("router server is listening on %s\n", ctx_http.listening_address.c_str());
|
||||
|
||||
+1
-2
@@ -28,10 +28,9 @@ vite.config.ts.timestamp-*
|
||||
# PWA Artifacts
|
||||
apple-splash-*.png
|
||||
apple-touch-icon-*.png
|
||||
favicon.ico
|
||||
favicon-dark.ico
|
||||
maskable-icon-*.png
|
||||
pwa-*.png
|
||||
static/favicon*
|
||||
|
||||
# Storybook
|
||||
*storybook.log
|
||||
|
||||
Generated
+7
-7
@@ -35,7 +35,7 @@
|
||||
"bits-ui": "2.18.1",
|
||||
"clsx": "2.1.1",
|
||||
"dexie": "4.4.3",
|
||||
"dompurify": "3.4.5",
|
||||
"dompurify": "3.4.11",
|
||||
"eslint": "9.39.4",
|
||||
"eslint-config-prettier": "10.1.8",
|
||||
"eslint-plugin-storybook": "10.4.2",
|
||||
@@ -8653,9 +8653,9 @@
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/dompurify": {
|
||||
"version": "3.4.5",
|
||||
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.5.tgz",
|
||||
"integrity": "sha512-OrwIBKsdNSVEeubdJ1HBv/wNENRM9ytAVCv7YXt//A3vPdVMNuACRqK9mXCGCBW2ln7BT/A4X0jXHo2Gu89miA==",
|
||||
"version": "3.4.11",
|
||||
"resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.11.tgz",
|
||||
"integrity": "sha512-zhlUV12GsaRzMsf9q5M254YhA4+VuF0fG+QFqu6aYpoGlKtz+w8//jBcGVYBgQkR5GHjUomejY84AV+/uPbWdw==",
|
||||
"dev": true,
|
||||
"license": "(MPL-2.0 OR Apache-2.0)",
|
||||
"optionalDependencies": {
|
||||
@@ -10226,9 +10226,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.12.23",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.23.tgz",
|
||||
"integrity": "sha512-eIaZ9qDgu7XV0pxOCrg7/WhnQ6Ivm22UcxhXx/A3dcbqbbYgBEkc6e/J/s7j2tS96zoB0S9VBdLwQNCWwUo4LA==",
|
||||
"version": "4.12.26",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.26.tgz",
|
||||
"integrity": "sha512-uyZtpnYxM9CmQ7QsQknM4zN8EftNqhON1qYeIKM0Se67CCEe2c44xyGURwB0axX2fBDu1dqHrHAc1hmNT8ITkw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
"bits-ui": "2.18.1",
|
||||
"clsx": "2.1.1",
|
||||
"dexie": "4.4.3",
|
||||
"dompurify": "3.4.5",
|
||||
"dompurify": "3.4.11",
|
||||
"eslint": "9.39.4",
|
||||
"eslint-config-prettier": "10.1.8",
|
||||
"eslint-plugin-storybook": "10.4.2",
|
||||
|
||||
@@ -1,4 +1,10 @@
|
||||
import { defineConfig } from '@vite-pwa/assets-generator/config';
|
||||
import { FAVICON_COLORS, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
|
||||
import { writeThemeFavicons } from './scripts/favicon-colorize';
|
||||
|
||||
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
});
|
||||
|
||||
export default defineConfig({
|
||||
headLinkOptions: {
|
||||
@@ -7,7 +13,8 @@ export default defineConfig({
|
||||
preset: {
|
||||
transparent: {
|
||||
sizes: [],
|
||||
favicons: [[48, 'favicon-dark.ico']]
|
||||
favicons: [[48, 'favicon-dark.ico']],
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
},
|
||||
maskable: {
|
||||
sizes: []
|
||||
|
||||
@@ -5,15 +5,32 @@ import {
|
||||
} from '@vite-pwa/assets-generator/config';
|
||||
import { readFileSync } from 'node:fs';
|
||||
import { resolve } from 'node:path';
|
||||
import { THEME_COLORS, PWA_GENERATOR_DEVICES, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa';
|
||||
import {
|
||||
THEME_COLORS,
|
||||
PWA_GENERATOR_DEVICES,
|
||||
PWA_ASSET_GENERATOR,
|
||||
FAVICON_COLORS
|
||||
} from './src/lib/constants/pwa';
|
||||
import { SplashOrientation } from './src/lib/enums/splash.enums';
|
||||
import { writeThemeFavicons } from './scripts/favicon-colorize';
|
||||
|
||||
writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, {
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
});
|
||||
|
||||
export default defineConfig({
|
||||
headLinkOptions: {
|
||||
preset: PWA_ASSET_GENERATOR.LINK_PRESET
|
||||
},
|
||||
preset: combinePresetAndAppleSplashScreens(
|
||||
minimal2023Preset,
|
||||
{
|
||||
...minimal2023Preset,
|
||||
// tiny margin so favicon.ico / pwa-*.png breathe inside the canvas
|
||||
transparent: {
|
||||
...minimal2023Preset.transparent,
|
||||
padding: PWA_ASSET_GENERATOR.FAVICON_PADDING
|
||||
}
|
||||
},
|
||||
{
|
||||
padding: PWA_ASSET_GENERATOR.SPLASH_PADDING,
|
||||
resizeOptions: {
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
import { mkdirSync, readFileSync, writeFileSync } from 'node:fs';
|
||||
import { dirname, resolve } from 'node:path';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
|
||||
const HERE = dirname(fileURLToPath(import.meta.url));
|
||||
const PROJECT_ROOT = resolve(HERE, '..');
|
||||
|
||||
const DEFAULT_LOGO = resolve(PROJECT_ROOT, 'src/lib/assets/logo.svg');
|
||||
const DEFAULT_OUT_DIR = resolve(PROJECT_ROOT, 'static');
|
||||
const DEFAULT_OUT_LIGHT = resolve(DEFAULT_OUT_DIR, 'favicon.svg');
|
||||
const DEFAULT_OUT_DARK = resolve(DEFAULT_OUT_DIR, 'favicon-dark.svg');
|
||||
|
||||
const CURRENT_COLOR = 'currentColor';
|
||||
|
||||
export interface ColorizedFavicon {
|
||||
light: string;
|
||||
dark: string;
|
||||
}
|
||||
|
||||
export interface WriteThemeFaviconsOptions {
|
||||
sourcePath?: string;
|
||||
lightOutPath?: string;
|
||||
darkOutPath?: string;
|
||||
/**
|
||||
* Fraction of the icon (0..1) to leave as an even margin on each side.
|
||||
* Applied by wrapping the inner content in a `<g transform="...">` so the
|
||||
* source `src/lib/assets/logo.svg` is not modified. Pass 0 to disable.
|
||||
*/
|
||||
padding?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace every `currentColor` occurrence in the SVG with the given color.
|
||||
* Pure: no filesystem access, so it is straightforward to unit-test.
|
||||
*/
|
||||
export function colorizeFaviconSvg(
|
||||
svg: string,
|
||||
lightColor: string,
|
||||
darkColor: string
|
||||
): ColorizedFavicon {
|
||||
return {
|
||||
light: svg.replaceAll(CURRENT_COLOR, lightColor),
|
||||
dark: svg.replaceAll(CURRENT_COLOR, darkColor)
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Shrink the inner SVG content uniformly and re-center it so `padding` (a
|
||||
* 0..1 fraction) is reserved as equal margin on each side. Returns the input
|
||||
* unchanged for non-positive padding, missing/invalid `viewBox`, or unexpected
|
||||
* markup so the caller always gets a renderable SVG.
|
||||
*/
|
||||
export function padFaviconSvg(svg: string, padding: number): string {
|
||||
if (!(padding > 0) || padding >= 1) return svg;
|
||||
|
||||
const viewBoxMatch = svg.match(/viewBox\s*=\s*["']([^"']+)["']/i);
|
||||
if (!viewBoxMatch) return svg;
|
||||
|
||||
const parts = viewBoxMatch[1]
|
||||
.trim()
|
||||
.split(/[\s,]+/)
|
||||
.map(Number);
|
||||
if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) return svg;
|
||||
|
||||
const [, , width, height] = parts;
|
||||
if (width <= 0 || height <= 0) return svg;
|
||||
|
||||
const scale = 1 - padding;
|
||||
const translateX = (padding * width) / 2;
|
||||
const translateY = (padding * height) / 2;
|
||||
|
||||
const openTagStart = svg.search(/<svg\b/i);
|
||||
if (openTagStart === -1) return svg;
|
||||
const openTagEnd = svg.indexOf('>', openTagStart);
|
||||
if (openTagEnd === -1) return svg;
|
||||
const closeStart = svg.lastIndexOf('</svg');
|
||||
if (closeStart === -1 || closeStart <= openTagEnd) return svg;
|
||||
|
||||
const openTag = svg.slice(0, openTagEnd + 1);
|
||||
const inner = svg.slice(openTagEnd + 1, closeStart);
|
||||
const closeTag = svg.slice(closeStart);
|
||||
|
||||
const group = `<g transform="translate(${translateX} ${translateY}) scale(${scale})">`;
|
||||
return `${openTag}${group}${inner}</g>${closeTag}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Read `src/lib/assets/logo.svg`, colorize it for both themes, and write
|
||||
* the results to the static directory so the PWA asset generator can consume
|
||||
* them. Paths can be overridden for tests.
|
||||
*/
|
||||
export function writeThemeFavicons(
|
||||
lightColor: string,
|
||||
darkColor: string,
|
||||
{
|
||||
sourcePath = DEFAULT_LOGO,
|
||||
lightOutPath = DEFAULT_OUT_LIGHT,
|
||||
darkOutPath = DEFAULT_OUT_DARK,
|
||||
padding = 0
|
||||
}: WriteThemeFaviconsOptions = {}
|
||||
): void {
|
||||
const source = readFileSync(sourcePath, 'utf-8');
|
||||
const { light, dark } = colorizeFaviconSvg(source, lightColor, darkColor);
|
||||
mkdirSync(dirname(lightOutPath), { recursive: true });
|
||||
writeFileSync(lightOutPath, padFaviconSvg(light, padding));
|
||||
writeFileSync(darkOutPath, padFaviconSvg(dark, padding));
|
||||
}
|
||||
@@ -48,6 +48,7 @@
|
||||
|
||||
--chat-form-area-height: 8rem;
|
||||
--chat-form-area-offset: 2rem;
|
||||
--chat-form-padding-top: 6rem;
|
||||
--max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem)));
|
||||
}
|
||||
|
||||
@@ -55,6 +56,7 @@
|
||||
:root {
|
||||
--chat-form-area-height: 24rem;
|
||||
--chat-form-area-offset: 12rem;
|
||||
--chat-form-padding-top: 6rem;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,7 +143,6 @@
|
||||
@apply bg-background text-foreground;
|
||||
scrollbar-width: thin;
|
||||
scrollbar-gutter: stable;
|
||||
overflow: hidden; /* Added due to Mermaid rendering somehow causing the double scrollbar */
|
||||
}
|
||||
|
||||
/* Global scrollbar styling - visible only on hover */
|
||||
@@ -193,3 +194,7 @@
|
||||
scrollbar-width: none;
|
||||
}
|
||||
}
|
||||
|
||||
.mermaidTooltip {
|
||||
display: none !important;
|
||||
}
|
||||
|
||||
@@ -10,9 +10,9 @@ import { isElementInViewport } from '$lib/utils/viewport';
|
||||
*/
|
||||
export function fadeInView(
|
||||
node: HTMLElement,
|
||||
options: { duration?: number; y?: number; skipIfVisible?: boolean } = {}
|
||||
options: { duration?: number; y?: number; delay?: number; skipIfVisible?: boolean } = {}
|
||||
) {
|
||||
const { duration = 300, y = 0, skipIfVisible = false } = options;
|
||||
const { duration = 300, y = 0, delay = 0, skipIfVisible = false } = options;
|
||||
|
||||
if (skipIfVisible && isElementInViewport(node)) {
|
||||
return;
|
||||
@@ -27,10 +27,12 @@ export function fadeInView(
|
||||
(entries) => {
|
||||
for (const entry of entries) {
|
||||
if (entry.isIntersecting) {
|
||||
requestAnimationFrame(() => {
|
||||
node.style.opacity = '1';
|
||||
node.style.transform = 'translateY(0)';
|
||||
});
|
||||
setTimeout(() => {
|
||||
requestAnimationFrame(() => {
|
||||
node.style.opacity = '1';
|
||||
node.style.transform = 'translateY(0)';
|
||||
});
|
||||
}, delay);
|
||||
observer.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M244.95 8C215.233 8 187.774 23.8591 172.923 49.5999L95.6009 183.625C60.2162 244.959 104.481 321.6 175.29 321.6H208L316.977 132.708C348.959 77.2719 308.95 8 244.95 8ZM208 321.6H351.947C415.982 321.6 456.013 390.91 424.013 446.377C409.155 472.132 381.681 488 351.947 488H271.29C200.481 488 156.216 411.359 191.601 350.026L208 321.6Z" fill="currentColor"/>
|
||||
<path d="M208 321.6H16L106.462 164.8L208 321.6Z" fill="currentColor"/>
|
||||
<path d="M388.923 8L208 321.6L253.6 8H388.923Z" fill="currentColor"/>
|
||||
<path d="M304 488H112L202.462 331.2L304 488Z" fill="currentColor"/>
|
||||
<path d="M496 321.6H208L419.399 454.4L496 321.6Z" fill="currentColor"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 771 B |
@@ -8,12 +8,13 @@
|
||||
ariaLabel?: string;
|
||||
class?: string;
|
||||
disabled?: boolean;
|
||||
href?: string;
|
||||
icon: Component;
|
||||
iconSize?: string;
|
||||
onclick: (e?: MouseEvent) => void;
|
||||
onclick?: (e?: MouseEvent) => void;
|
||||
size?: ButtonSize;
|
||||
stopPropagationOnClick?: boolean;
|
||||
tooltip: string;
|
||||
tooltip?: string;
|
||||
variant?: ButtonVariant;
|
||||
tooltipSide?: TooltipSide;
|
||||
}
|
||||
@@ -22,6 +23,7 @@
|
||||
icon,
|
||||
tooltip,
|
||||
variant = 'ghost',
|
||||
href = '',
|
||||
size = 'sm',
|
||||
class: className = '',
|
||||
disabled = false,
|
||||
@@ -31,34 +33,49 @@
|
||||
onclick,
|
||||
ariaLabel
|
||||
}: Props = $props();
|
||||
|
||||
let innerWidth = $state(0);
|
||||
const showTooltip = $derived(!!tooltip && innerWidth > 768);
|
||||
</script>
|
||||
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<!-- prevent another nested button element -->
|
||||
{#snippet child({ props })}
|
||||
<Button
|
||||
{...props}
|
||||
{variant}
|
||||
{size}
|
||||
{disabled}
|
||||
onclick={(e: MouseEvent) => {
|
||||
if (stopPropagationOnClick) e.stopPropagation();
|
||||
{#snippet button(props = {})}
|
||||
<Button
|
||||
{...props}
|
||||
{href}
|
||||
{variant}
|
||||
{size}
|
||||
{disabled}
|
||||
onclick={(e: MouseEvent) => {
|
||||
if (stopPropagationOnClick) e.stopPropagation();
|
||||
|
||||
onclick?.(e);
|
||||
}}
|
||||
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
|
||||
aria-label={ariaLabel || tooltip}
|
||||
>
|
||||
{#if icon}
|
||||
{@const IconComponent = icon}
|
||||
<IconComponent class={iconSize} />
|
||||
{/if}
|
||||
</Button>
|
||||
{/snippet}
|
||||
</Tooltip.Trigger>
|
||||
onclick?.(e);
|
||||
}}
|
||||
class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!"
|
||||
aria-label={ariaLabel || tooltip}
|
||||
>
|
||||
{#if icon}
|
||||
{@const IconComponent = icon}
|
||||
|
||||
<Tooltip.Content side={tooltipSide}>
|
||||
<p>{tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
<IconComponent class={iconSize} />
|
||||
{/if}
|
||||
</Button>
|
||||
{/snippet}
|
||||
|
||||
{#if showTooltip}
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger>
|
||||
<!-- prevent another nested button element -->
|
||||
{#snippet child({ props })}
|
||||
{@render button(props)}
|
||||
{/snippet}
|
||||
</Tooltip.Trigger>
|
||||
|
||||
<Tooltip.Content side={tooltipSide}>
|
||||
<p>{tooltip}</p>
|
||||
</Tooltip.Content>
|
||||
</Tooltip.Root>
|
||||
{:else}
|
||||
{@render button({ href })}
|
||||
{/if}
|
||||
|
||||
<svelte:window bind:innerWidth />
|
||||
|
||||
@@ -494,7 +494,7 @@
|
||||
/>
|
||||
|
||||
<div
|
||||
class="{INPUT_CLASSES} overflow-hidden rounded-3xl backdrop-blur-md {disabled
|
||||
class="{INPUT_CLASSES} overflow-hidden rounded-4xl md:rounded-3xl backdrop-blur-md {disabled
|
||||
? 'cursor-not-allowed opacity-60'
|
||||
: ''}"
|
||||
data-slot="input-area"
|
||||
@@ -510,7 +510,7 @@
|
||||
/>
|
||||
|
||||
<div
|
||||
class="flex-column relative min-h-[48px] items-center rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:!py-3"
|
||||
class="flex-column relative min-h-12 items-center rounded-4xl md:rounded-3xl py-2 pb-2.25 shadow-sm transition-all focus-within:shadow-md md:py-3!"
|
||||
onpaste={handlePaste}
|
||||
>
|
||||
<ChatFormTextarea
|
||||
|
||||
+1
-1
@@ -15,7 +15,7 @@
|
||||
<Tooltip.Root>
|
||||
<Tooltip.Trigger class="w-full">
|
||||
<Button
|
||||
class="file-upload-button h-8 w-8 rounded-full p-0"
|
||||
class="file-upload-button md:h-8 md:w-8 h-9 w-9 rounded-full p-0"
|
||||
{disabled}
|
||||
{onclick}
|
||||
variant="secondary"
|
||||
|
||||
+16
-3
@@ -15,6 +15,7 @@
|
||||
import { McpLogo } from '$lib/components/app';
|
||||
import { PencilRuler, ChevronDown, ChevronRight } from '@lucide/svelte';
|
||||
import { HealthCheckStatus } from '$lib/enums';
|
||||
import { AttachmentAction } from '$lib/enums/attachment.enums';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -270,14 +271,22 @@
|
||||
</Collapsible.Root>
|
||||
{/if}
|
||||
|
||||
<button type="button" class={sheetItemClass} onclick={onSystemPromptClick}>
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[AttachmentAction.SYSTEM_PROMPT_CLICK]()}
|
||||
>
|
||||
<MessageSquare class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>System Message</span>
|
||||
</button>
|
||||
|
||||
{#if hasMcpPromptsSupport}
|
||||
<button type="button" class={sheetItemClass} onclick={onMcpPromptClick}>
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_PROMPT_CLICK]()}
|
||||
>
|
||||
<Zap class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>MCP Prompt</span>
|
||||
@@ -285,7 +294,11 @@
|
||||
{/if}
|
||||
|
||||
{#if hasMcpResourcesSupport}
|
||||
<button type="button" class={sheetItemClass} onclick={onMcpResourcesClick}>
|
||||
<button
|
||||
type="button"
|
||||
class={sheetItemClass}
|
||||
onclick={() => attachmentMenu.callbacks[AttachmentAction.MCP_RESOURCES_CLICK]()}
|
||||
>
|
||||
<FolderOpen class="h-4 w-4 shrink-0" />
|
||||
|
||||
<span>MCP Resources</span>
|
||||
|
||||
+1
@@ -42,6 +42,7 @@
|
||||
{hasMcpPromptsSupport}
|
||||
{hasMcpResourcesSupport}
|
||||
{onFileUpload}
|
||||
{onSystemPromptClick}
|
||||
{onMcpPromptClick}
|
||||
{onMcpResourcesClick}
|
||||
>
|
||||
|
||||
+1
-1
@@ -20,7 +20,7 @@
|
||||
type="submit"
|
||||
disabled={isDisabled}
|
||||
class={[
|
||||
'h-8 w-8 rounded-full p-0',
|
||||
'md:h-8 md:w-8 h-9 w-9 rounded-full p-0',
|
||||
showErrorState &&
|
||||
'bg-red-400/10 text-red-400 hover:bg-red-400/20 hover:text-red-400 disabled:opacity-100'
|
||||
]}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
<script lang="ts">
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { autoResizeTextarea } from '$lib/utils';
|
||||
import { onMount } from 'svelte';
|
||||
|
||||
@@ -37,7 +38,9 @@
|
||||
}
|
||||
|
||||
export function focus() {
|
||||
textareaElement?.focus();
|
||||
if (isMobile.current) return;
|
||||
|
||||
textareaElement?.focus({ preventScroll: true });
|
||||
}
|
||||
|
||||
export function resetHeight() {
|
||||
|
||||
@@ -231,7 +231,7 @@
|
||||
editedContent = message.content;
|
||||
}
|
||||
|
||||
textareaElement?.focus();
|
||||
textareaElement?.focus({ preventScroll: true });
|
||||
editedExtras = message.extra ? [...message.extra] : [];
|
||||
editedUploadedFiles = [];
|
||||
|
||||
@@ -324,7 +324,7 @@
|
||||
}
|
||||
</script>
|
||||
|
||||
<div use:fadeInView>
|
||||
<div use:fadeInView class="chat-message">
|
||||
{#if message.role === MessageRole.SYSTEM}
|
||||
<ChatMessageSystem
|
||||
bind:textareaElement
|
||||
|
||||
+72
-5
@@ -180,6 +180,9 @@
|
||||
|
||||
let displayedModel = $derived(message.model ?? null);
|
||||
|
||||
// model being switched to while it loads, so the selector bar tracks it
|
||||
let pendingModel = $state<string | null>(null);
|
||||
|
||||
let isCurrentlyLoading = $derived(isLoading());
|
||||
let isStreaming = $derived(isChatStreaming());
|
||||
let hasNoContent = $derived(!message?.content?.trim());
|
||||
@@ -207,6 +210,42 @@
|
||||
isLastAssistantMessage
|
||||
);
|
||||
|
||||
let assistantEl: HTMLDivElement | undefined = $state();
|
||||
let lastUserMessageHeight = $state(0);
|
||||
let assistantMarginTop = $state(0);
|
||||
|
||||
$effect(() => {
|
||||
if (!assistantEl) return;
|
||||
|
||||
assistantMarginTop = Math.round(parseFloat(getComputedStyle(assistantEl).marginTop));
|
||||
|
||||
const chatMessageEl = assistantEl.closest('.chat-message');
|
||||
const previousChatMessage = chatMessageEl?.previousElementSibling;
|
||||
const userMessageEl = previousChatMessage?.querySelector(
|
||||
'.chat-message-user'
|
||||
) as HTMLElement | null;
|
||||
|
||||
if (!userMessageEl) {
|
||||
lastUserMessageHeight = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
const updateHeight = () => {
|
||||
const rect = userMessageEl.getBoundingClientRect();
|
||||
const marginTop = Math.round(parseFloat(getComputedStyle(userMessageEl).marginTop));
|
||||
lastUserMessageHeight = Math.round(rect.height + marginTop);
|
||||
};
|
||||
|
||||
updateHeight();
|
||||
|
||||
const resizeObserver = new ResizeObserver(updateHeight);
|
||||
resizeObserver.observe(userMessageEl);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
});
|
||||
|
||||
function handleCopyModel() {
|
||||
void copyToClipboard(displayedModel ?? '');
|
||||
}
|
||||
@@ -219,12 +258,17 @@
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="text-md group w-full leading-7.5 {className}"
|
||||
bind:this={assistantEl}
|
||||
class="chat-message-assistant text-md group w-full leading-7.5 {className}"
|
||||
style:--last-user-message-height={lastUserMessageHeight > 0
|
||||
? `${lastUserMessageHeight}px`
|
||||
: undefined}
|
||||
style:--assistant-margin-top={assistantMarginTop > 0 ? `${assistantMarginTop}px` : undefined}
|
||||
role="group"
|
||||
aria-label="Assistant message with actions"
|
||||
>
|
||||
{#if showProcessingInfoTop}
|
||||
<div class="mt-6 w-full max-w-[48rem]" in:fade>
|
||||
<div class="mt-6 w-full max-w-3xl" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{modelLoadingText ??
|
||||
@@ -257,7 +301,7 @@
|
||||
{/if}
|
||||
|
||||
{#if showProcessingInfoBottom}
|
||||
<div class="mt-4 w-full max-w-[48rem]" in:fade>
|
||||
<div class="mt-4 w-full max-w-3xl" in:fade>
|
||||
<div class="processing-container">
|
||||
<span class="processing-text">
|
||||
{modelLoadingText ??
|
||||
@@ -277,13 +321,19 @@
|
||||
>
|
||||
{#if isRouter}
|
||||
<ModelsSelectorDropdown
|
||||
currentModel={displayedModel}
|
||||
currentModel={pendingModel ?? displayedModel}
|
||||
disabled={isLoading()}
|
||||
onModelChange={async (modelId: string, modelName: string) => {
|
||||
const status = modelsStore.getModelStatus(modelId);
|
||||
|
||||
if (status !== ServerModelStatus.LOADED) {
|
||||
await modelsStore.loadModel(modelId);
|
||||
pendingModel = modelId;
|
||||
|
||||
try {
|
||||
await modelsStore.loadModel(modelId);
|
||||
} finally {
|
||||
pendingModel = null;
|
||||
}
|
||||
}
|
||||
|
||||
onRegenerate(modelName);
|
||||
@@ -351,6 +401,23 @@
|
||||
</div>
|
||||
|
||||
<style>
|
||||
:global(.chat-message):last-child .chat-message-assistant {
|
||||
--assistant-min-height-offset: calc(
|
||||
var(--last-user-message-height, 19rem) + var(--chat-form-height, 6rem) +
|
||||
var(--chat-form-bottom-position, 0.5rem) + var(--chat-form-padding-top, 6rem) +
|
||||
var(--assistant-margin-top, 3rem)
|
||||
);
|
||||
min-height: calc(100dvh - var(--assistant-min-height-offset));
|
||||
|
||||
@media (width > 768px) {
|
||||
--assistant-min-height-offset: calc(
|
||||
var(--last-user-message-height, 18rem) + var(--chat-form-height, 6rem) +
|
||||
var(--chat-form-bottom-position, 1rem) + var(--chat-form-padding-top, 6rem) +
|
||||
var(--assistant-margin-top, 3rem)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
.processing-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
|
||||
+1
-1
@@ -48,7 +48,7 @@
|
||||
|
||||
<div
|
||||
aria-label="User message with actions"
|
||||
class="group flex flex-col items-end gap-3 md:gap-2 {className}"
|
||||
class="chat-message-user group flex flex-col items-end gap-3 md:gap-2 {className}"
|
||||
role="group"
|
||||
>
|
||||
{#if editCtx.isEditing}
|
||||
|
||||
+2
-2
@@ -19,7 +19,7 @@
|
||||
renderMarkdown = false,
|
||||
textColorClass = 'text-foreground',
|
||||
cardBgClass = 'dark:bg-primary/15',
|
||||
maxHeightStyle = 'max-height: var(--max-message-height);'
|
||||
maxHeightStyle = ''
|
||||
}: Props = $props();
|
||||
|
||||
let isMultiline = $state(false);
|
||||
@@ -59,7 +59,7 @@
|
||||
|
||||
{#if content.trim()}
|
||||
<Card
|
||||
class="max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-[multiline]:py-2.5 {cardBgClass}"
|
||||
class="chat-message-user-bubble max-w-[80%] overflow-y-auto rounded-[1.125rem] border-none bg-primary/5 px-3.75 py-1.5 {textColorClass} backdrop-blur-md data-multiline:py-2.5 {cardBgClass}"
|
||||
data-multiline={isMultiline ? '' : undefined}
|
||||
style="{maxHeightStyle} overflow-wrap: anywhere; word-break: break-word;"
|
||||
>
|
||||
|
||||
@@ -37,6 +37,7 @@
|
||||
let allConversationMessages = $state<DatabaseMessage[]>([]);
|
||||
let isVisible = $state(false);
|
||||
let previousConversationId = $state<string | null>(null);
|
||||
let previousRouteId = $state<string | null>(null);
|
||||
|
||||
const currentConfig = config();
|
||||
|
||||
@@ -157,8 +158,9 @@
|
||||
});
|
||||
});
|
||||
|
||||
beforeNavigate(() => {
|
||||
beforeNavigate((navigation) => {
|
||||
isVisible = false;
|
||||
previousRouteId = navigation.from?.route.id ?? null;
|
||||
});
|
||||
|
||||
afterNavigate(() => {
|
||||
@@ -249,12 +251,13 @@
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="transition-opacity delay-300 duration-500 ease-out
|
||||
{isVisible ? 'opacity-100' : 'opacity-0'}"
|
||||
class="transition-opacity duration-500 ease-out
|
||||
{isVisible ? 'opacity-100' : 'opacity-0'}
|
||||
{previousRouteId === '/(chat)/chat/[id]' ? '' : 'delay-300'}"
|
||||
>
|
||||
{#each displayMessages as { message, toolMessages, isLastAssistantMessage, siblingInfo } (message.id)}
|
||||
<ChatMessage
|
||||
class="mx-auto mt-12 w-full max-w-[48rem]"
|
||||
class="mx-auto mt-12 w-full max-w-3xl"
|
||||
{message}
|
||||
{toolMessages}
|
||||
{isLastAssistantMessage}
|
||||
|
||||
@@ -1,31 +1,28 @@
|
||||
<script lang="ts">
|
||||
import { Trash2 } from '@lucide/svelte';
|
||||
import { afterNavigate } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import {
|
||||
ChatScreenForm,
|
||||
ChatMessages,
|
||||
ChatScreenDragOverlay,
|
||||
ChatScreenProcessingInfo,
|
||||
ChatScreenActionScrollDown,
|
||||
DialogEmptyFileAlert,
|
||||
DialogFileUploadError,
|
||||
DialogChatError,
|
||||
ServerLoadingSplash,
|
||||
DialogConfirmation,
|
||||
ChatScreenServerError
|
||||
} from '$lib/components/app';
|
||||
import { setProcessingInfoContext } from '$lib/contexts';
|
||||
import { ErrorDialogType } from '$lib/enums';
|
||||
import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte';
|
||||
import { useChatScreenActiveModel } from '$lib/hooks/use-chat-screen-active-model.svelte';
|
||||
import { useChatScreenDragAndDrop } from '$lib/hooks/use-chat-screen-drag-and-drop.svelte';
|
||||
import { useChatScreenFileUpload } from '$lib/hooks/use-chat-screen-file-upload.svelte';
|
||||
import { useChatScreenScroll } from '$lib/hooks/use-chat-screen-scroll.svelte';
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
import { device } from '$lib/stores/device.svelte';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import {
|
||||
chatStore,
|
||||
errorDialog,
|
||||
isLoading,
|
||||
isChatStreaming,
|
||||
isEditing,
|
||||
getAddFilesHandler,
|
||||
activeProcessingState
|
||||
} from '$lib/stores/chat.svelte';
|
||||
import {
|
||||
@@ -34,138 +31,81 @@
|
||||
activeConversation
|
||||
} from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { serverLoading, serverError, isRouterMode } from '$lib/stores/server.svelte';
|
||||
import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte';
|
||||
import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils';
|
||||
import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only';
|
||||
import { onMount } from 'svelte';
|
||||
import { serverLoading, serverError } from '$lib/stores/server.svelte';
|
||||
import { parseFilesToMessageExtras } from '$lib/utils/browser-only';
|
||||
import { onDestroy, onMount } from 'svelte';
|
||||
import ChatScreenGreeting from './ChatScreenGreeting.svelte';
|
||||
import ChatScreenActionScrollDown from './ChatScreenActionScrollDown.svelte';
|
||||
import ChatScreenDialogsAndAlerts from './ChatScreenDialogsAndAlerts.svelte';
|
||||
import { ROUTES } from '$lib/constants';
|
||||
|
||||
let { showCenteredEmpty = false } = $props();
|
||||
|
||||
const autoScroll = createAutoScrollController();
|
||||
|
||||
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll));
|
||||
let chatScrollContainer: HTMLDivElement | undefined = $state();
|
||||
let dragCounter = $state(0);
|
||||
let isDragOver = $state(false);
|
||||
let showFileErrorDialog = $state(false);
|
||||
let uploadedFiles = $state<ChatUploadedFile[]>([]);
|
||||
|
||||
let fileErrorData = $state<{
|
||||
generallyUnsupported: File[];
|
||||
modalityUnsupported: File[];
|
||||
modalityReasons: Record<string, string>;
|
||||
supportedTypes: string[];
|
||||
}>({
|
||||
generallyUnsupported: [],
|
||||
modalityUnsupported: [],
|
||||
modalityReasons: {},
|
||||
supportedTypes: []
|
||||
});
|
||||
|
||||
let showDeleteDialog = $state(false);
|
||||
|
||||
let showEmptyFileDialog = $state(false);
|
||||
|
||||
let processingInfoVisible = $state(false);
|
||||
|
||||
let emptyFileNames = $state<string[]>([]);
|
||||
|
||||
let initialMessage = $state('');
|
||||
|
||||
let isEmpty = $derived(
|
||||
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
|
||||
);
|
||||
|
||||
let activeErrorDialog = $derived(errorDialog());
|
||||
let isServerLoading = $derived(serverLoading());
|
||||
let hasPropsError = $derived(!!serverError());
|
||||
|
||||
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
|
||||
|
||||
let showProcessingInfo = $derived(
|
||||
isCurrentConversationLoading ||
|
||||
(config().keepStatsVisible && !!page.params.id) ||
|
||||
activeProcessingState() !== null
|
||||
);
|
||||
|
||||
let isRouter = $derived(isRouterMode());
|
||||
|
||||
let conversationModel = $derived(
|
||||
chatStore.getConversationModel(activeMessages() as DatabaseMessage[])
|
||||
);
|
||||
|
||||
let activeModelId = $derived.by(() => {
|
||||
const options = modelOptions();
|
||||
|
||||
if (!isRouter) {
|
||||
return options.length > 0 ? options[0].model : null;
|
||||
}
|
||||
|
||||
const selectedId = selectedModelId();
|
||||
if (selectedId) {
|
||||
const model = options.find((m) => m.id === selectedId);
|
||||
if (model) return model.model;
|
||||
}
|
||||
|
||||
if (conversationModel) {
|
||||
const model = options.find((m) => m.model === conversationModel);
|
||||
if (model) return model.model;
|
||||
}
|
||||
|
||||
return null;
|
||||
});
|
||||
|
||||
let modelPropsVersion = $state(0);
|
||||
|
||||
setProcessingInfoContext({
|
||||
get showProcessingInfo() {
|
||||
return showProcessingInfo;
|
||||
}
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
if (activeModelId) {
|
||||
const cached = modelsStore.getModelProps(activeModelId);
|
||||
let disableAutoScroll = $derived(Boolean(config().disableAutoScroll) || isMobile.current);
|
||||
let isMobileUserScrolledUp = $state(false);
|
||||
let mobileScrollDownHint = $state(false);
|
||||
let mobileScrollDownHintLockedUntil = $state(0);
|
||||
let emptyFileNames = $state<string[]>([]);
|
||||
let initialMessage = $state('');
|
||||
let showDeleteDialog = $state(false);
|
||||
let showEmptyFileDialog = $state(false);
|
||||
let isEmpty = $derived(
|
||||
showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading()
|
||||
);
|
||||
let activeErrorDialog = $derived(errorDialog());
|
||||
let isServerLoading = $derived(serverLoading());
|
||||
let hasPropsError = $derived(!!serverError());
|
||||
let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming());
|
||||
let showProcessingInfo = $derived(
|
||||
isCurrentConversationLoading ||
|
||||
(config().keepStatsVisible && !!page.params.id) ||
|
||||
activeProcessingState() !== null
|
||||
);
|
||||
let chatFormBottomPosition = $derived.by(() => {
|
||||
if (!isMobile.current) return '1rem';
|
||||
if (device.isStandalone) return '1.5rem';
|
||||
if (device.isIOSSafari) return '0.25rem';
|
||||
return '0.5rem';
|
||||
});
|
||||
|
||||
if (!cached) {
|
||||
modelsStore.fetchModelProps(activeModelId).then(() => {
|
||||
modelPropsVersion++;
|
||||
});
|
||||
const autoScroll = createAutoScrollController();
|
||||
const scroll = useChatScreenScroll(autoScroll);
|
||||
const activeModel = useChatScreenActiveModel();
|
||||
const fileUpload = useChatScreenFileUpload({
|
||||
capabilities: () => ({
|
||||
hasVision: activeModel.hasVisionModality,
|
||||
hasAudio: activeModel.hasAudioModality,
|
||||
hasVideo: activeModel.hasVideoModality
|
||||
}),
|
||||
activeModelId: () => activeModel.activeModelId
|
||||
});
|
||||
const dragAndDrop = useChatScreenDragAndDrop({
|
||||
onDrop: fileUpload.handleFileUpload
|
||||
});
|
||||
const { handleKeydown } = useKeyboardShortcuts({
|
||||
deleteActiveConversation: () => {
|
||||
if (activeConversation()) {
|
||||
showDeleteDialog = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let hasAudioModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
function handleMobileScroll() {
|
||||
if (!isMobile.current) return;
|
||||
|
||||
return modelsStore.modelSupportsAudio(activeModelId);
|
||||
}
|
||||
const container = scroll.chatScrollContainer;
|
||||
if (!container) return;
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
let hasVideoModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
|
||||
return modelsStore.modelSupportsVideo(activeModelId);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
let hasVisionModality = $derived.by(() => {
|
||||
if (activeModelId) {
|
||||
void modelPropsVersion;
|
||||
|
||||
return modelsStore.modelSupportsVision(activeModelId);
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
const distanceFromBottom =
|
||||
container.scrollHeight - container.clientHeight - container.scrollTop;
|
||||
isMobileUserScrolledUp = distanceFromBottom > 300;
|
||||
}
|
||||
|
||||
async function handleDeleteConfirm() {
|
||||
const conversation = activeConversation();
|
||||
@@ -177,27 +117,69 @@
|
||||
showDeleteDialog = false;
|
||||
}
|
||||
|
||||
function handleProcessingInfoVisibility(visible: boolean) {
|
||||
processingInfoVisible = visible;
|
||||
}
|
||||
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
|
||||
const plainFiles = files ? $state.snapshot(files) : undefined;
|
||||
const result = plainFiles
|
||||
? await parseFilesToMessageExtras(plainFiles, activeModel.activeModelId ?? undefined)
|
||||
: undefined;
|
||||
|
||||
function handleDragEnter(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
|
||||
dragCounter++;
|
||||
|
||||
if (event.dataTransfer?.types.includes('Files')) {
|
||||
isDragOver = true;
|
||||
if (result?.emptyFiles && result.emptyFiles.length > 0) {
|
||||
emptyFileNames = result.emptyFiles;
|
||||
showEmptyFileDialog = true;
|
||||
if (files) {
|
||||
const emptyFileNamesSet = new Set(result.emptyFiles);
|
||||
fileUpload.uploadedFiles = fileUpload.uploadedFiles.filter(
|
||||
(file) => !emptyFileNamesSet.has(file.name)
|
||||
);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
handleSendLikeScroll();
|
||||
|
||||
await chatStore.sendMessage(message, result?.extras);
|
||||
return true;
|
||||
}
|
||||
|
||||
function handleDragLeave(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
function handleSendLikeScroll() {
|
||||
if (!isMobile.current) {
|
||||
autoScroll.enable();
|
||||
}
|
||||
|
||||
dragCounter--;
|
||||
setTimeout(() => {
|
||||
const container = scroll.chatScrollContainer;
|
||||
if (!container) return;
|
||||
|
||||
if (dragCounter === 0) {
|
||||
isDragOver = false;
|
||||
const lastUserBubble = container.querySelector(
|
||||
'.chat-message:nth-last-child(2) .chat-message-user .chat-message-user-bubble'
|
||||
) as HTMLElement | null;
|
||||
|
||||
if (isMobile.current) {
|
||||
// Keep the last user message bubble just above the input on mobile
|
||||
const bubbleHeight = lastUserBubble?.scrollHeight ?? 0;
|
||||
const baseHeight = container.scrollHeight - innerHeight;
|
||||
|
||||
container.scrollTo({
|
||||
top: bubbleHeight > 0 ? baseHeight - bubbleHeight : baseHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
} else if (lastUserBubble) {
|
||||
// On desktop, place the last user message near the top of the viewport
|
||||
const topPadding = 24;
|
||||
const bubbleRect = lastUserBubble.getBoundingClientRect();
|
||||
container.scrollTo({
|
||||
top: Math.max(0, container.scrollTop + bubbleRect.top - topPadding),
|
||||
behavior: 'smooth'
|
||||
});
|
||||
} else {
|
||||
autoScroll.scrollToBottom();
|
||||
}
|
||||
}, 100);
|
||||
|
||||
if (isMobile.current) {
|
||||
autoScroll.setDisabled(disableAutoScroll);
|
||||
mobileScrollDownHint = true;
|
||||
mobileScrollDownHintLockedUntil = Date.now() + 500;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,273 +189,138 @@
|
||||
}
|
||||
}
|
||||
|
||||
function handleDragOver(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
}
|
||||
|
||||
function handleDrop(event: DragEvent) {
|
||||
event.preventDefault();
|
||||
|
||||
isDragOver = false;
|
||||
dragCounter = 0;
|
||||
|
||||
if (event.dataTransfer?.files) {
|
||||
const files = Array.from(event.dataTransfer.files);
|
||||
|
||||
if (isEditing()) {
|
||||
const handler = getAddFilesHandler();
|
||||
|
||||
if (handler) {
|
||||
handler(files);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
processFiles(files);
|
||||
}
|
||||
}
|
||||
|
||||
function handleFileRemove(fileId: string) {
|
||||
uploadedFiles = uploadedFiles.filter((f) => f.id !== fileId);
|
||||
}
|
||||
|
||||
function handleFileUpload(files: File[]) {
|
||||
processFiles(files);
|
||||
}
|
||||
|
||||
const { handleKeydown } = useKeyboardShortcuts({
|
||||
deleteActiveConversation: () => {
|
||||
if (activeConversation()) {
|
||||
showDeleteDialog = true;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) {
|
||||
if (draft.message || draft.files.length > 0) {
|
||||
chatStore.savePendingDraft(draft.message, draft.files);
|
||||
}
|
||||
|
||||
await chatStore.addSystemPrompt();
|
||||
}
|
||||
|
||||
function handleScroll() {
|
||||
autoScroll.handleScroll();
|
||||
}
|
||||
|
||||
async function handleSendMessage(message: string, files?: ChatUploadedFile[]): Promise<boolean> {
|
||||
const plainFiles = files ? $state.snapshot(files) : undefined;
|
||||
const result = plainFiles
|
||||
? await parseFilesToMessageExtras(plainFiles, activeModelId ?? undefined)
|
||||
: undefined;
|
||||
|
||||
if (result?.emptyFiles && result.emptyFiles.length > 0) {
|
||||
emptyFileNames = result.emptyFiles;
|
||||
showEmptyFileDialog = true;
|
||||
|
||||
if (files) {
|
||||
const emptyFileNamesSet = new Set(result.emptyFiles);
|
||||
uploadedFiles = uploadedFiles.filter((file) => !emptyFileNamesSet.has(file.name));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const extras = result?.extras;
|
||||
|
||||
// Enable autoscroll for user-initiated message sending
|
||||
autoScroll.enable();
|
||||
await chatStore.sendMessage(message, extras);
|
||||
autoScroll.scrollToBottom();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
async function processFiles(files: File[]) {
|
||||
const generallySupported: File[] = [];
|
||||
const generallyUnsupported: File[] = [];
|
||||
|
||||
for (const file of files) {
|
||||
if (isFileTypeSupported(file.name, file.type)) {
|
||||
generallySupported.push(file);
|
||||
} else {
|
||||
generallyUnsupported.push(file);
|
||||
}
|
||||
}
|
||||
|
||||
// Use model-specific capabilities for file validation
|
||||
const capabilities = {
|
||||
hasVision: hasVisionModality,
|
||||
hasAudio: hasAudioModality,
|
||||
hasVideo: hasVideoModality
|
||||
};
|
||||
const { supportedFiles, unsupportedFiles, modalityReasons } = filterFilesByModalities(
|
||||
generallySupported,
|
||||
capabilities
|
||||
);
|
||||
|
||||
const allUnsupportedFiles = [...generallyUnsupported, ...unsupportedFiles];
|
||||
|
||||
if (allUnsupportedFiles.length > 0) {
|
||||
const supportedTypes: string[] = ['text files', 'PDFs'];
|
||||
|
||||
if (hasVisionModality) supportedTypes.push('images');
|
||||
if (hasAudioModality) supportedTypes.push('audio files');
|
||||
if (hasVideoModality) supportedTypes.push('video files');
|
||||
|
||||
fileErrorData = {
|
||||
generallyUnsupported,
|
||||
modalityUnsupported: unsupportedFiles,
|
||||
modalityReasons,
|
||||
supportedTypes
|
||||
};
|
||||
showFileErrorDialog = true;
|
||||
}
|
||||
|
||||
if (supportedFiles.length > 0) {
|
||||
const processed = await processFilesToChatUploaded(
|
||||
supportedFiles,
|
||||
activeModelId ?? undefined
|
||||
);
|
||||
uploadedFiles = [...uploadedFiles, ...processed];
|
||||
}
|
||||
}
|
||||
|
||||
afterNavigate(() => {
|
||||
if (!disableAutoScroll) {
|
||||
$effect(() => {
|
||||
const shouldDisableAutoScroll =
|
||||
config().disableAutoScroll || (isMobile.current && isCurrentConversationLoading);
|
||||
autoScroll.setDisabled(shouldDisableAutoScroll);
|
||||
if (!shouldDisableAutoScroll) {
|
||||
autoScroll.enable();
|
||||
}
|
||||
});
|
||||
|
||||
function handleMessagesReady() {
|
||||
if (disableAutoScroll) return;
|
||||
|
||||
if (!autoScroll.userScrolledUp) {
|
||||
requestAnimationFrame(() => {
|
||||
autoScroll.scrollToBottom('instant');
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
const pendingDraft = chatStore.consumePendingDraft();
|
||||
if (pendingDraft) {
|
||||
initialMessage = pendingDraft.message;
|
||||
fileUpload.uploadedFiles = pendingDraft.files;
|
||||
}
|
||||
|
||||
autoScroll.startObserving();
|
||||
|
||||
if (!disableAutoScroll) {
|
||||
autoScroll.enable();
|
||||
}
|
||||
|
||||
const pendingDraft = chatStore.consumePendingDraft();
|
||||
if (pendingDraft) {
|
||||
initialMessage = pendingDraft.message;
|
||||
uploadedFiles = pendingDraft.files;
|
||||
if (isMobile.current && isCurrentConversationLoading) {
|
||||
mobileScrollDownHint = true;
|
||||
mobileScrollDownHintLockedUntil = Date.now() + 500;
|
||||
}
|
||||
|
||||
handleMobileScroll();
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
autoScroll.setContainer(chatScrollContainer);
|
||||
});
|
||||
|
||||
$effect(() => {
|
||||
autoScroll.setDisabled(disableAutoScroll);
|
||||
});
|
||||
onDestroy(() => autoScroll.destroy());
|
||||
</script>
|
||||
|
||||
{#if isDragOver}
|
||||
{#if dragAndDrop.isDragOver}
|
||||
<ChatScreenDragOverlay />
|
||||
{/if}
|
||||
|
||||
<svelte:window onkeydown={handleKeydown} />
|
||||
<svelte:window
|
||||
onkeydown={handleKeydown}
|
||||
onscroll={(e) => {
|
||||
scroll.handleScroll(e);
|
||||
handleMobileScroll();
|
||||
if (e.isTrusted && Date.now() > mobileScrollDownHintLockedUntil) {
|
||||
mobileScrollDownHint = false;
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
{#if isServerLoading}
|
||||
<ServerLoadingSplash />
|
||||
{:else}
|
||||
<div
|
||||
bind:this={chatScrollContainer}
|
||||
aria-label="Chat interface with file drop zone"
|
||||
class="flex h-full flex-col overflow-y-auto px-4 md:px-6"
|
||||
ondragenter={handleDragEnter}
|
||||
ondragleave={handleDragLeave}
|
||||
ondragover={handleDragOver}
|
||||
ondrop={handleDrop}
|
||||
onscroll={handleScroll}
|
||||
class="chat-screen flex grow flex-col min-h-[calc(100dvh-1rem)] md:min-h-full px-4 md:py-0 pt-12 pb-48 md:pb-4"
|
||||
style:--chat-form-bottom-position={chatFormBottomPosition}
|
||||
ondragenter={dragAndDrop.dragHandlers.dragenter}
|
||||
ondragleave={dragAndDrop.dragHandlers.dragleave}
|
||||
ondragover={dragAndDrop.dragHandlers.dragover}
|
||||
ondrop={dragAndDrop.dragHandlers.drop}
|
||||
role="main"
|
||||
>
|
||||
<div class="flex grow flex-col pt-14">
|
||||
{#if !isEmpty}
|
||||
<ChatMessages
|
||||
messages={activeMessages()}
|
||||
onMessagesReady={handleMessagesReady}
|
||||
onUserAction={() => {
|
||||
autoScroll.enable();
|
||||
if (!autoScroll.userScrolledUp) {
|
||||
autoScroll.scrollToBottom();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
{#if !isEmpty}
|
||||
<ChatMessages
|
||||
messages={activeMessages()}
|
||||
onUserAction={() => {
|
||||
handleSendLikeScroll();
|
||||
}}
|
||||
/>
|
||||
{/if}
|
||||
|
||||
<div
|
||||
class={[
|
||||
'pointer-events-none sticky right-4 left-4 mt-auto transition-all duration-200',
|
||||
isEmpty ? 'bottom-[calc(50dvh-7rem)]' : 'bottom-4 pt-24 md:pt-32'
|
||||
]}
|
||||
>
|
||||
<ChatScreenGreeting {isEmpty} />
|
||||
<div
|
||||
class={[
|
||||
'pointer-events-none md:sticky fixed mt-auto transition-all duration-200',
|
||||
device.isStandalone
|
||||
? 'bottom-6 right-4 left-4'
|
||||
: device.isIOSSafari
|
||||
? 'bottom-1 left-2 right-2'
|
||||
: 'bottom-2 right-2 left-2',
|
||||
isEmpty ? 'md:bottom-[calc(50dvh-7rem)] 2xl:bottom-[calc(50dvh-4rem)]' : 'md:bottom-4'
|
||||
]}
|
||||
style:padding-top={!isEmpty ? 'var(--chat-form-padding-top)' : undefined}
|
||||
>
|
||||
<ChatScreenGreeting {isEmpty} />
|
||||
|
||||
<ChatScreenActionScrollDown
|
||||
container={chatScrollContainer}
|
||||
hasProcessingInfoVisible={processingInfoVisible}
|
||||
/>
|
||||
<ChatScreenServerError />
|
||||
|
||||
<ChatScreenProcessingInfo onVisibilityChange={handleProcessingInfoVisibility} />
|
||||
|
||||
<ChatScreenServerError />
|
||||
|
||||
<div class="conversation-chat-form pointer-events-auto rounded-t-3xl">
|
||||
<ChatScreenForm
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={handleFileRemove}
|
||||
onFileUpload={handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
bind:uploadedFiles
|
||||
<div class="pointer-events-none flex flex-col gap-6 items-center w-full">
|
||||
{#if (isMobile.current ? mobileScrollDownHint || isMobileUserScrolledUp : autoScroll.userScrolledUp) && page.url.hash.includes(ROUTES.CHAT) && page.params.id}
|
||||
<ChatScreenActionScrollDown
|
||||
onclick={() => {
|
||||
mobileScrollDownHint = false;
|
||||
scroll.chatScrollContainer?.scrollTo({
|
||||
top: scroll.chatScrollContainer.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
{#if showProcessingInfo}
|
||||
<ChatScreenProcessingInfo />
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<ChatScreenForm
|
||||
class="pointer-events-auto conversation-chat-form"
|
||||
disabled={hasPropsError || isEditing()}
|
||||
{initialMessage}
|
||||
isLoading={isCurrentConversationLoading}
|
||||
onFileRemove={fileUpload.handleFileRemove}
|
||||
onFileUpload={fileUpload.handleFileUpload}
|
||||
onSend={handleSendMessage}
|
||||
onStop={() => chatStore.stopGeneration()}
|
||||
onSystemPromptAdd={handleSystemPromptAdd}
|
||||
bind:uploadedFiles={fileUpload.uploadedFiles}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<DialogFileUploadError bind:open={showFileErrorDialog} {fileErrorData} />
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showDeleteDialog}
|
||||
title="Delete Conversation"
|
||||
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
|
||||
confirmText="Delete"
|
||||
cancelText="Cancel"
|
||||
variant="destructive"
|
||||
icon={Trash2}
|
||||
onConfirm={handleDeleteConfirm}
|
||||
onCancel={() => (showDeleteDialog = false)}
|
||||
/>
|
||||
|
||||
<DialogEmptyFileAlert
|
||||
bind:open={showEmptyFileDialog}
|
||||
emptyFiles={emptyFileNames}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
emptyFileNames = [];
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<DialogChatError
|
||||
message={activeErrorDialog?.message ?? ''}
|
||||
contextInfo={activeErrorDialog?.contextInfo}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
|
||||
<ChatScreenDialogsAndAlerts
|
||||
{showDeleteDialog}
|
||||
{handleDeleteConfirm}
|
||||
{showEmptyFileDialog}
|
||||
{emptyFileNames}
|
||||
{activeErrorDialog}
|
||||
{handleErrorDialogOpenChange}
|
||||
{fileUpload}
|
||||
/>
|
||||
|
||||
@@ -1,58 +1,18 @@
|
||||
<script lang="ts">
|
||||
import { ArrowDown } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import ActionIcon from '$lib/components/app/actions/ActionIcon.svelte';
|
||||
|
||||
interface Props {
|
||||
container: HTMLDivElement | undefined;
|
||||
hasProcessingInfoVisible: boolean;
|
||||
}
|
||||
|
||||
let { container, hasProcessingInfoVisible }: Props = $props();
|
||||
|
||||
let show = $state(false);
|
||||
|
||||
let buttonBottom = $derived(hasProcessingInfoVisible ? '2rem' : '0');
|
||||
|
||||
function checkVisibility() {
|
||||
if (!container) return;
|
||||
const { scrollTop, scrollHeight, clientHeight } = container;
|
||||
const distanceFromBottom = scrollHeight - clientHeight - scrollTop;
|
||||
show = distanceFromBottom > clientHeight * 0.5;
|
||||
}
|
||||
|
||||
function scrollToBottom() {
|
||||
if (container) {
|
||||
container.scrollTo({
|
||||
top: container.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
const c = container;
|
||||
if (c) {
|
||||
c.addEventListener('scroll', checkVisibility);
|
||||
checkVisibility();
|
||||
return () => {
|
||||
c.removeEventListener('scroll', checkVisibility);
|
||||
};
|
||||
}
|
||||
});
|
||||
let { onclick }: { onclick: (e?: MouseEvent) => void } = $props();
|
||||
</script>
|
||||
|
||||
<div class="relative z-50 mx-auto mb-4 flex max-w-[48rem] justify-center">
|
||||
<Button
|
||||
onclick={scrollToBottom}
|
||||
variant="secondary"
|
||||
size="icon"
|
||||
disabled={!show}
|
||||
class="pointer-events-auto absolute h-10 w-10 rounded-full bg-background/80 shadow-lg backdrop-blur-sm transition-all duration-200 hover:bg-muted/80"
|
||||
style="bottom: {buttonBottom}; transform: translateY({show ? '0' : '2rem'}); opacity: {show
|
||||
? 1
|
||||
: 0};"
|
||||
aria-label="Scroll to bottom"
|
||||
>
|
||||
<ArrowDown class="h-4 w-4" />
|
||||
</Button>
|
||||
<div class="pointer-events-auto flex justify-center relative h-0">
|
||||
<ActionIcon
|
||||
icon={ArrowDown}
|
||||
{onclick}
|
||||
ariaLabel="Scroll to bottom"
|
||||
tooltip="Scroll to bottom"
|
||||
size="lg"
|
||||
iconSize="h-4 w-4"
|
||||
class="h-9 w-9 rounded-full bg-accent text-accent-foreground absolute bottom-4 shadow-md"
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
<script lang="ts">
|
||||
import { Trash2 } from '@lucide/svelte';
|
||||
import { ErrorDialogType } from '$lib/enums';
|
||||
import {
|
||||
DialogChatError,
|
||||
DialogConfirmation,
|
||||
DialogEmptyFileAlert,
|
||||
DialogFileUploadError
|
||||
} from '$lib/components/app';
|
||||
|
||||
let {
|
||||
showDeleteDialog,
|
||||
handleDeleteConfirm,
|
||||
showEmptyFileDialog,
|
||||
emptyFileNames,
|
||||
activeErrorDialog,
|
||||
handleErrorDialogOpenChange,
|
||||
fileUpload
|
||||
} = $props();
|
||||
</script>
|
||||
|
||||
<DialogFileUploadError
|
||||
bind:open={fileUpload.showFileErrorDialog}
|
||||
fileErrorData={fileUpload.fileErrorData}
|
||||
/>
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showDeleteDialog}
|
||||
title="Delete Conversation"
|
||||
description="Are you sure you want to delete this conversation? This action cannot be undone and will permanently remove all messages in this conversation."
|
||||
confirmText="Delete"
|
||||
cancelText="Cancel"
|
||||
variant="destructive"
|
||||
icon={Trash2}
|
||||
onConfirm={handleDeleteConfirm}
|
||||
onCancel={() => (showDeleteDialog = false)}
|
||||
/>
|
||||
|
||||
<DialogEmptyFileAlert
|
||||
bind:open={showEmptyFileDialog}
|
||||
emptyFiles={emptyFileNames}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
emptyFileNames = [];
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<DialogChatError
|
||||
message={activeErrorDialog?.message ?? ''}
|
||||
contextInfo={activeErrorDialog?.contextInfo}
|
||||
onOpenChange={handleErrorDialogOpenChange}
|
||||
open={Boolean(activeErrorDialog)}
|
||||
type={activeErrorDialog?.type ?? ErrorDialogType.SERVER}
|
||||
/>
|
||||
@@ -2,6 +2,7 @@
|
||||
import { afterNavigate } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { ChatForm } from '$lib/components/app';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { onMount } from 'svelte';
|
||||
import { useDraftMessages } from '$lib/hooks/use-draft-messages.svelte';
|
||||
|
||||
@@ -32,7 +33,30 @@
|
||||
}: Props = $props();
|
||||
|
||||
let chatFormRef: ChatForm | undefined = $state(undefined);
|
||||
let formWrapperEl: HTMLDivElement | undefined = $state();
|
||||
let chatId = $derived(page.params.id as string | undefined);
|
||||
|
||||
$effect(() => {
|
||||
if (!formWrapperEl) return;
|
||||
|
||||
const formEl = formWrapperEl.querySelector('form') as HTMLElement | null;
|
||||
if (!formEl) return;
|
||||
|
||||
const updateHeight = () => {
|
||||
const height = Math.round(formEl.getBoundingClientRect().height);
|
||||
document.documentElement.style.setProperty('--chat-form-height', `${height}px`);
|
||||
};
|
||||
|
||||
updateHeight();
|
||||
|
||||
const resizeObserver = new ResizeObserver(updateHeight);
|
||||
resizeObserver.observe(formEl);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
document.documentElement.style.removeProperty('--chat-form-height');
|
||||
};
|
||||
});
|
||||
let hasLoadingAttachments = $derived(uploadedFiles.some((f) => f.isLoading));
|
||||
let message = $derived(initialMessage);
|
||||
let previousIsLoading = $derived(isLoading);
|
||||
@@ -83,12 +107,14 @@
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
setTimeout(() => chatFormRef?.focus(), 10);
|
||||
if (!isMobile.current) {
|
||||
setTimeout(() => chatFormRef?.focus(), 100);
|
||||
}
|
||||
});
|
||||
|
||||
afterNavigate((navigation) => {
|
||||
if (navigation?.from != null) {
|
||||
setTimeout(() => chatFormRef?.focus(), 10);
|
||||
if (navigation?.from != null && !isMobile.current) {
|
||||
setTimeout(() => chatFormRef?.focus(), 100);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -108,12 +134,12 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="relative mx-auto max-w-[48rem]">
|
||||
<div class="chat-screen-form-wrapper" bind:this={formWrapperEl}>
|
||||
<ChatForm
|
||||
class="mx-auto max-w-3xl {className}"
|
||||
bind:this={chatFormRef}
|
||||
bind:value={message}
|
||||
bind:uploadedFiles
|
||||
class={className}
|
||||
{disabled}
|
||||
{isLoading}
|
||||
showMcpPromptButton
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
<script lang="ts">
|
||||
import { fadeInView } from '$lib/actions/fade-in-view.svelte';
|
||||
import { serverStore } from '$lib/stores/server.svelte';
|
||||
|
||||
interface Props {
|
||||
@@ -11,10 +10,9 @@
|
||||
|
||||
<div
|
||||
class={[
|
||||
'pointer-events-none mb-4 hidden px-4 text-center',
|
||||
isEmpty && 'pointer-events-auto block!'
|
||||
'pointer-events-none mb-4 hidden px-4 text-center text-balance',
|
||||
isEmpty && 'mb-[calc(50dvh-8rem)] md:mb-6 pointer-events-auto block!'
|
||||
]}
|
||||
use:fadeInView={{ duration: 300 }}
|
||||
>
|
||||
<h1 class="mb-2 text-2xl font-semibold tracking-tight md:text-3xl">Hello there</h1>
|
||||
|
||||
|
||||
@@ -5,13 +5,8 @@
|
||||
import { chatStore, isLoading, isChatStreaming } from '$lib/stores/chat.svelte';
|
||||
import { activeMessages, activeConversation } from '$lib/stores/conversations.svelte';
|
||||
import { config } from '$lib/stores/settings.svelte';
|
||||
import { getProcessingInfoContext } from '$lib/contexts';
|
||||
import { page } from '$app/state';
|
||||
|
||||
const processingState = useProcessingState();
|
||||
const processingInfoCtx = getProcessingInfoContext();
|
||||
|
||||
let showProcessingInfo = $derived(processingInfoCtx.showProcessingInfo);
|
||||
|
||||
let isCurrentConversationLoading = $derived(isLoading());
|
||||
let isStreaming = $derived(isChatStreaming());
|
||||
@@ -70,8 +65,8 @@
|
||||
|
||||
<div
|
||||
class={[
|
||||
'chat-processing-info-container pointer-events-none relative',
|
||||
page.params.id && showProcessingInfo && 'visible'
|
||||
'chat-processing-info-container pointer-events-none relative w-full hidden md:block',
|
||||
processingVisible && 'visible'
|
||||
]}
|
||||
>
|
||||
<div class="chat-processing-info-content absolute bottom-4 left-1/2 -translate-x-1/2">
|
||||
|
||||
@@ -677,13 +677,6 @@ export { default as ChatScreenForm } from './ChatScreen/ChatScreenForm.svelte';
|
||||
*/
|
||||
export { default as ChatScreenProcessingInfo } from './ChatScreen/ChatScreenProcessingInfo.svelte';
|
||||
|
||||
/**
|
||||
* Scroll-to-bottom action button. Displays a floating button when the user
|
||||
* has scrolled up more than half a viewport height from the bottom.
|
||||
* Takes the chat container element as a prop to manage scroll state internally.
|
||||
*/
|
||||
export { default as ChatScreenActionScrollDown } from './ChatScreen/ChatScreenActionScrollDown.svelte';
|
||||
|
||||
/**
|
||||
* Server error alert displayed when the server is unreachable.
|
||||
* Shows the error message with a retry button.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import { Search, X } from '@lucide/svelte';
|
||||
|
||||
interface Props {
|
||||
autofocus?: boolean;
|
||||
value?: string;
|
||||
placeholder?: string;
|
||||
onInput?: (value: string) => void;
|
||||
@@ -15,6 +16,7 @@
|
||||
}
|
||||
|
||||
let {
|
||||
autofocus,
|
||||
value = $bindable(''),
|
||||
placeholder = 'Search...',
|
||||
onInput,
|
||||
@@ -39,7 +41,7 @@
|
||||
if (value) {
|
||||
value = '';
|
||||
onInput?.('');
|
||||
ref?.focus();
|
||||
ref?.focus({ preventScroll: true });
|
||||
} else {
|
||||
onClose?.();
|
||||
}
|
||||
@@ -52,6 +54,7 @@
|
||||
/>
|
||||
|
||||
<Input
|
||||
{autofocus}
|
||||
{id}
|
||||
bind:value
|
||||
bind:ref
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
<script>
|
||||
import logoMark from '$lib/assets/logo.svg?raw';
|
||||
let { class: className = '', style = '' } = $props();
|
||||
</script>
|
||||
|
||||
<div class={className} {style}>
|
||||
{@html logoMark}
|
||||
</div>
|
||||
|
||||
<style>
|
||||
div :global(svg) {
|
||||
width: var(--size, 1rem);
|
||||
height: var(--size, 1rem);
|
||||
}
|
||||
</style>
|
||||
@@ -51,3 +51,11 @@ export { default as KeyboardShortcutInfo } from './KeyboardShortcutInfo.svelte';
|
||||
* Preview button is shown only for HTML code blocks.
|
||||
*/
|
||||
export { default as CodeBlockActions } from './CodeBlockActions.svelte';
|
||||
|
||||
/**
|
||||
* **Logo** - Application brand mark
|
||||
*
|
||||
* Inline SVG of the application logo. Accepts styling via the standard
|
||||
* `class` and `style` props and inherits color via `currentColor`.
|
||||
*/
|
||||
export { default as Logo } from './Logo.svelte';
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
<script lang="ts">
|
||||
let { percent }: { percent: number } = $props();
|
||||
</script>
|
||||
|
||||
<!-- thin determinate load bar pinned to the bottom edge, pulsing while it fills -->
|
||||
<div class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm">
|
||||
<div
|
||||
class="h-full animate-pulse bg-primary transition-[width] duration-200 ease-out"
|
||||
style="width: {percent}%"
|
||||
></div>
|
||||
</div>
|
||||
@@ -2,8 +2,10 @@
|
||||
import { ChevronDown, Loader2, Package } from '@lucide/svelte';
|
||||
import * as DropdownMenu from '$lib/components/ui/dropdown-menu';
|
||||
import * as Tooltip from '$lib/components/ui/tooltip';
|
||||
import { KeyboardKey } from '$lib/enums';
|
||||
import { KeyboardKey, ServerModelStatus } from '$lib/enums';
|
||||
import { useModelsSelector } from '$lib/hooks/use-models-selector.svelte';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
import { modelLoadFraction } from '$lib/utils';
|
||||
import {
|
||||
DialogModelInformation,
|
||||
DropdownMenuSearchable,
|
||||
@@ -11,6 +13,7 @@
|
||||
ModelsSelectorList,
|
||||
ModelsSelectorOption
|
||||
} from '$lib/components/app';
|
||||
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
|
||||
import type { ModelItem } from './utils';
|
||||
|
||||
interface Props {
|
||||
@@ -113,6 +116,17 @@
|
||||
{/if}
|
||||
{:else}
|
||||
{@const selectedOption = ms.getDisplayOption()}
|
||||
{@const triggerModel = selectedOption?.model}
|
||||
{@const triggerStatus = triggerModel
|
||||
? routerModels().find((m) => m.id === triggerModel)?.status?.value
|
||||
: undefined}
|
||||
{@const triggerLoading =
|
||||
!!triggerModel &&
|
||||
(triggerStatus === ServerModelStatus.LOADING ||
|
||||
modelsStore.isModelOperationInProgress(triggerModel))}
|
||||
{@const triggerLoadPercent = triggerLoading
|
||||
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
|
||||
: 0}
|
||||
|
||||
{#if ms.isRouter}
|
||||
<DropdownMenu.Root bind:open={isOpen} onOpenChange={ms.handleOpenChange}>
|
||||
@@ -123,7 +137,7 @@
|
||||
<DropdownMenu.Trigger
|
||||
{...props}
|
||||
class={[
|
||||
`inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
`relative inline-grid cursor-pointer grid-cols-[1fr_auto_1fr] items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
!ms.isCurrentModelInCache
|
||||
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
|
||||
: forceForegroundText
|
||||
@@ -154,6 +168,10 @@
|
||||
{:else}
|
||||
<ChevronDown class="h-3 w-3.5 shrink-0" />
|
||||
{/if}
|
||||
|
||||
{#if triggerLoading}
|
||||
<ModelLoadHighlight percent={triggerLoadPercent} />
|
||||
{/if}
|
||||
</DropdownMenu.Trigger>
|
||||
{/snippet}
|
||||
</Tooltip.Trigger>
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
RotateCw
|
||||
} from '@lucide/svelte';
|
||||
import { ActionIcon, ModelId } from '$lib/components/app';
|
||||
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
|
||||
import type { ModelOption } from '$lib/types/models';
|
||||
import { ServerModelStatus } from '$lib/enums';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
@@ -119,11 +120,11 @@
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-5 items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-5">
|
||||
<Loader2 class="h-4 w-4 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
{:else if isFailed}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<CircleAlert
|
||||
class="h-3.5 w-3.5 text-red-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
/>
|
||||
@@ -140,7 +141,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else if isSleeping}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-orange-400 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
@@ -159,7 +160,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else if isLoaded}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-green-500 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
@@ -176,7 +177,7 @@
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="flex w-4 [@media(pointer:coarse)]:w-auto items-center justify-center">
|
||||
<div class="flex w-4 items-center justify-center [@media(pointer:coarse)]:w-auto">
|
||||
<span
|
||||
class="h-2 w-2 rounded-full bg-muted-foreground/50 group-hover:hidden [@media(pointer:coarse)]:hidden"
|
||||
></span>
|
||||
@@ -196,13 +197,6 @@
|
||||
</div>
|
||||
|
||||
{#if isLoading}
|
||||
<div
|
||||
class="pointer-events-none absolute inset-x-0 bottom-0 h-0.5 overflow-hidden rounded-b-sm bg-muted"
|
||||
>
|
||||
<div
|
||||
class="h-full bg-primary transition-[width] duration-200 ease-out"
|
||||
style="width: {loadPercent}%"
|
||||
></div>
|
||||
</div>
|
||||
<ModelLoadHighlight percent={loadPercent} />
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
@@ -8,6 +8,10 @@
|
||||
ModelsSelectorList,
|
||||
SearchInput
|
||||
} from '$lib/components/app';
|
||||
import ModelLoadHighlight from './ModelLoadHighlight.svelte';
|
||||
import { ServerModelStatus } from '$lib/enums';
|
||||
import { modelsStore, routerModels } from '$lib/stores/models.svelte';
|
||||
import { modelLoadFraction } from '$lib/utils';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -61,12 +65,23 @@
|
||||
<p class="text-xs text-muted-foreground">No models available.</p>
|
||||
{:else}
|
||||
{@const selectedOption = ms.getDisplayOption()}
|
||||
{@const triggerModel = selectedOption?.model}
|
||||
{@const triggerStatus = triggerModel
|
||||
? routerModels().find((m) => m.id === triggerModel)?.status?.value
|
||||
: undefined}
|
||||
{@const triggerLoading =
|
||||
!!triggerModel &&
|
||||
(triggerStatus === ServerModelStatus.LOADING ||
|
||||
modelsStore.isModelOperationInProgress(triggerModel))}
|
||||
{@const triggerLoadPercent = triggerLoading
|
||||
? Math.round(modelLoadFraction(modelsStore.getLoadProgress(triggerModel)) * 100)
|
||||
: 0}
|
||||
|
||||
{#if ms.isRouter}
|
||||
<button
|
||||
type="button"
|
||||
class={[
|
||||
`inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 max-sm:px-3 max-sm:py-2 text-xs max-sm:text-sm shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
`relative inline-flex cursor-pointer items-center gap-1.5 rounded-sm bg-background px-1.5 py-1 text-xs shadow-sm transition hover:bg-muted-foreground/20 focus:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-60 max-sm:px-3 max-sm:py-2 max-sm:text-sm dark:bg-muted-foreground/15 dark:text-secondary-foreground`,
|
||||
!ms.isCurrentModelInCache
|
||||
? 'bg-red-400/10 !text-red-400 hover:bg-red-400/20 hover:text-red-400'
|
||||
: forceForegroundText
|
||||
@@ -99,6 +114,10 @@
|
||||
{:else}
|
||||
<ChevronDown class="h-3 w-3.5 shrink-0" />
|
||||
{/if}
|
||||
|
||||
{#if triggerLoading}
|
||||
<ModelLoadHighlight percent={triggerLoadPercent} />
|
||||
{/if}
|
||||
</button>
|
||||
|
||||
<Sheet.Root bind:open={sheetOpen} onOpenChange={handleSheetOpenChange}>
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { ActionIcon } from '$lib/components/app';
|
||||
import {
|
||||
ICON_STRIP_TRANSITION_DURATION,
|
||||
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
|
||||
SIDEBAR_ACTIONS_ITEMS
|
||||
} from '$lib/constants';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
import { fade } from 'svelte/transition';
|
||||
import { circIn } from 'svelte/easing';
|
||||
import { onMount } from 'svelte';
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
|
||||
interface Props {
|
||||
sidebarOpen: boolean;
|
||||
onSearchClick: () => void;
|
||||
}
|
||||
|
||||
let { sidebarOpen = false, onSearchClick }: Props = $props();
|
||||
|
||||
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
|
||||
|
||||
let initialized = $state(false);
|
||||
let showIcons = $derived(!sidebarOpen);
|
||||
|
||||
showIcons = false;
|
||||
|
||||
onMount(() => {
|
||||
showIcons = !sidebarOpen;
|
||||
|
||||
setTimeout(() => {
|
||||
initialized = true;
|
||||
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
|
||||
});
|
||||
</script>
|
||||
|
||||
<svelte:window onkeydown={handleKeydown} />
|
||||
|
||||
<div
|
||||
class="hidden shrink-0 transition-[width] duration-200 ease-linear md:block {sidebarOpen
|
||||
? 'w-0'
|
||||
: 'w-[calc(var(--sidebar-width-icon)+1.5rem)]'}"
|
||||
></div>
|
||||
<aside
|
||||
class="fixed top-0 bottom-0 left-0 z-10 hidden w-[calc(var(--sidebar-width-icon)+1.5rem)] flex-col items-center justify-between py-3 transition-opacity duration-200 ease-linear md:flex {sidebarOpen
|
||||
? 'pointer-events-none opacity-0'
|
||||
: 'opacity-100'}"
|
||||
>
|
||||
<div class="mt-12 flex flex-col items-center gap-1">
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
|
||||
{@const onclick = item.route ? () => goto(item.route!) : onSearchClick}
|
||||
{@const isActive = item.activeRouteId
|
||||
? page.route.id === item.activeRouteId
|
||||
: item.activeRoutePrefix
|
||||
? !!page.route.id?.startsWith(item.activeRoutePrefix)
|
||||
: false}
|
||||
{#if showIcons}
|
||||
<div
|
||||
in:fade={{
|
||||
duration: ICON_STRIP_TRANSITION_DURATION,
|
||||
delay: !initialized
|
||||
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
|
||||
: 0,
|
||||
easing: circIn
|
||||
}}
|
||||
>
|
||||
<ActionIcon
|
||||
icon={item.icon}
|
||||
tooltip={item.tooltip}
|
||||
tooltipSide={TooltipSide.RIGHT}
|
||||
size="lg"
|
||||
iconSize="h-4 w-4"
|
||||
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
{onclick}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
</aside>
|
||||
+234
-295
@@ -1,40 +1,67 @@
|
||||
<script lang="ts">
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { Trash2, Pencil, Pin, X } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { DialogConfirmation } from '$lib/components/app';
|
||||
import SidebarNavigationActions from './SidebarNavigationActions.svelte';
|
||||
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
|
||||
import { Checkbox } from '$lib/components/ui/checkbox';
|
||||
import Label from '$lib/components/ui/label/label.svelte';
|
||||
import ScrollArea from '$lib/components/ui/scroll-area/scroll-area.svelte';
|
||||
import * as Sidebar from '$lib/components/ui/sidebar';
|
||||
import Input from '$lib/components/ui/input/input.svelte';
|
||||
import { ROUTES } from '$lib/constants/routes';
|
||||
import { RouterService } from '$lib/services/router.service';
|
||||
import { PanelLeftClose, PanelLeftOpen, X } from '@lucide/svelte';
|
||||
import {
|
||||
conversationsStore,
|
||||
conversations,
|
||||
buildConversationTree
|
||||
} from '$lib/stores/conversations.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { getPreviewText } from '$lib/utils';
|
||||
import { APP_NAME } from '$lib/constants';
|
||||
ActionIcon,
|
||||
Logo,
|
||||
SidebarNavigationConversationList,
|
||||
SidebarNavigationActions
|
||||
} from '$lib/components/app';
|
||||
import { ROUTES } from '$lib/constants';
|
||||
import { fade } from 'svelte/transition';
|
||||
|
||||
const sidebar = Sidebar.useSidebar();
|
||||
import { useKeyboardShortcuts } from '$lib/hooks/use-keyboard-shortcuts.svelte';
|
||||
import { conversationsStore, conversations } from '$lib/stores/conversations.svelte';
|
||||
import { chatStore } from '$lib/stores/chat.svelte';
|
||||
import { RouterService } from '$lib/services/router.service';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
import { device } from '$lib/stores/device.svelte';
|
||||
import { circIn } from 'svelte/easing';
|
||||
|
||||
interface Props {
|
||||
onSearchClick?: () => void;
|
||||
}
|
||||
|
||||
let { onSearchClick = () => {} }: Props = $props();
|
||||
|
||||
const { handleKeydown } = useKeyboardShortcuts({ activateSearchMode: () => onSearchClick() });
|
||||
|
||||
let isExpandedMode = $state(false);
|
||||
let hoveredTooltip = $state<string | null>(null);
|
||||
let logoHovered = $state(false);
|
||||
|
||||
const isStripExpanded = $derived(isExpandedMode || hoveredTooltip !== null);
|
||||
const isOnMobile = $derived(isMobile.current);
|
||||
|
||||
function toggleExpandedMode() {
|
||||
isExpandedMode = !isExpandedMode;
|
||||
if (!isExpandedMode) {
|
||||
hoveredTooltip = null;
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (!isExpandedMode) {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
cancelMobileCollapse();
|
||||
}
|
||||
});
|
||||
|
||||
// On mobile the dedicated /search route hides the sidebar (see the aside
|
||||
// render guard below). Collapse it as we enter /search so it doesn't
|
||||
// reappear expanded when the user navigates back via the back button.
|
||||
$effect(() => {
|
||||
if (isMobile.current && page.url.hash.includes(ROUTES.SEARCH)) {
|
||||
isExpandedMode = false;
|
||||
}
|
||||
});
|
||||
|
||||
let currentChatId = $derived(page.params.id);
|
||||
let isSearchModeActive = $state(false);
|
||||
let searchQuery = $state('');
|
||||
let showDeleteDialog = $state(false);
|
||||
let deleteWithForks = $state(false);
|
||||
let showEditDialog = $state(false);
|
||||
let selectedConversation = $state<DatabaseConversation | null>(null);
|
||||
let editedName = $state('');
|
||||
let selectedConversationNamePreview = $derived.by(() =>
|
||||
selectedConversation ? getPreviewText(selectedConversation.name) : ''
|
||||
);
|
||||
|
||||
let filteredConversations = $derived.by(() => {
|
||||
if (isSearchModeActive) {
|
||||
@@ -50,294 +77,206 @@
|
||||
return conversations();
|
||||
});
|
||||
|
||||
let conversationTree = $derived(buildConversationTree(filteredConversations));
|
||||
|
||||
let pinnedConversations = $derived.by(() => {
|
||||
return conversationTree.filter(({ conversation }) => conversation.pinned);
|
||||
});
|
||||
|
||||
let unpinnedConversations = $derived.by(() => {
|
||||
return conversationTree.filter(({ conversation }) => !conversation.pinned);
|
||||
});
|
||||
|
||||
let selectedConversationHasDescendants = $derived.by(() => {
|
||||
if (!selectedConversation) return false;
|
||||
|
||||
const allConvs = conversations();
|
||||
const queue = [selectedConversation.id];
|
||||
|
||||
while (queue.length > 0) {
|
||||
const parentId = queue.pop()!;
|
||||
|
||||
for (const c of allConvs) {
|
||||
if (c.forkedFromConversationId === parentId) return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
async function handleDeleteConversation(id: string) {
|
||||
const conversation = conversations().find((conv) => conv.id === id);
|
||||
if (conversation) {
|
||||
selectedConversation = conversation;
|
||||
deleteWithForks = false;
|
||||
showDeleteDialog = true;
|
||||
async function selectConversation(id: string) {
|
||||
if (isMobile.current) {
|
||||
scheduleMobileCollapse();
|
||||
}
|
||||
await goto(RouterService.chat(id));
|
||||
}
|
||||
|
||||
async function handleEditConversation(id: string) {
|
||||
const conversation = conversations().find((conv) => conv.id === id);
|
||||
if (conversation) {
|
||||
selectedConversation = conversation;
|
||||
editedName = conversation.name;
|
||||
showEditDialog = true;
|
||||
if (!conversation) return;
|
||||
|
||||
const newName = window.prompt('Rename conversation', conversation.name);
|
||||
if (newName && newName.trim()) {
|
||||
await conversationsStore.updateConversationName(id, newName.trim());
|
||||
}
|
||||
}
|
||||
|
||||
function handleConfirmDelete() {
|
||||
if (selectedConversation) {
|
||||
const convId = selectedConversation.id;
|
||||
const withForks = deleteWithForks;
|
||||
showDeleteDialog = false;
|
||||
async function handleDeleteConversation(id: string) {
|
||||
const conversation = conversations().find((conv) => conv.id === id);
|
||||
if (!conversation) return;
|
||||
|
||||
setTimeout(() => {
|
||||
conversationsStore.deleteConversation(convId, {
|
||||
deleteWithForks: withForks
|
||||
});
|
||||
}, 100); // Wait for animation to finish
|
||||
}
|
||||
}
|
||||
const confirmed = window.confirm(
|
||||
`Delete "${conversation.name}"? This action cannot be undone.`
|
||||
);
|
||||
if (!confirmed) return;
|
||||
|
||||
function handleConfirmEdit() {
|
||||
if (!editedName.trim() || !selectedConversation) return;
|
||||
|
||||
showEditDialog = false;
|
||||
|
||||
conversationsStore.updateConversationName(selectedConversation.id, editedName);
|
||||
selectedConversation = null;
|
||||
}
|
||||
|
||||
export function handleMobileSidebarItemClick() {
|
||||
if (sidebar.isMobile) {
|
||||
sidebar.toggle();
|
||||
}
|
||||
}
|
||||
|
||||
let chatSidebarActions: { activateSearch?: () => void } | undefined = $state();
|
||||
let openedForSearch = $state(false);
|
||||
|
||||
export function activateSearchMode() {
|
||||
if (!sidebar.open) {
|
||||
openedForSearch = true;
|
||||
}
|
||||
chatSidebarActions?.activateSearch?.();
|
||||
}
|
||||
|
||||
function handleSearchDeactivated() {
|
||||
if (openedForSearch) {
|
||||
openedForSearch = false;
|
||||
sidebar.toggle();
|
||||
}
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (!sidebar.open) {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
openedForSearch = false;
|
||||
}
|
||||
});
|
||||
|
||||
export function editActiveConversation() {
|
||||
if (currentChatId) {
|
||||
const activeConversation = filteredConversations.find((conv) => conv.id === currentChatId);
|
||||
|
||||
if (activeConversation) {
|
||||
const event = new CustomEvent('edit-active-conversation', {
|
||||
detail: { conversationId: currentChatId }
|
||||
});
|
||||
document.dispatchEvent(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function selectConversation(id: string) {
|
||||
if (isSearchModeActive) {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
}
|
||||
|
||||
handleMobileSidebarItemClick();
|
||||
await goto(RouterService.chat(id));
|
||||
await conversationsStore.deleteConversation(id, { deleteWithForks: false });
|
||||
}
|
||||
|
||||
function handleStopGeneration(id: string) {
|
||||
chatStore.stopGenerationForChat(id);
|
||||
}
|
||||
|
||||
let innerWidth = $state(0);
|
||||
let pendingCollapse = $state<ReturnType<typeof setTimeout> | null>(null);
|
||||
|
||||
function scheduleMobileCollapse() {
|
||||
if (pendingCollapse) {
|
||||
clearTimeout(pendingCollapse);
|
||||
}
|
||||
pendingCollapse = setTimeout(() => {
|
||||
isExpandedMode = false;
|
||||
pendingCollapse = null;
|
||||
}, 100);
|
||||
}
|
||||
|
||||
function cancelMobileCollapse() {
|
||||
if (pendingCollapse) {
|
||||
clearTimeout(pendingCollapse);
|
||||
pendingCollapse = null;
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="flex h-full flex-col">
|
||||
<ScrollArea class="h-full flex-1">
|
||||
<Sidebar.Header class="gap-4 bg-sidebar/50 p-3 backdrop-blur-lg md:pt-4 md:pb-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<a href={ROUTES.START} onclick={handleMobileSidebarItemClick}>
|
||||
<h1 class="inline-flex items-center gap-1 px-2 text-xl font-semibold">
|
||||
{APP_NAME}
|
||||
</h1>
|
||||
</a>
|
||||
<svelte:window onkeydown={handleKeydown} bind:innerWidth />
|
||||
|
||||
<Button
|
||||
class="rounded-full md:hidden"
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
onclick={() => sidebar.toggle()}
|
||||
>
|
||||
<X class="h-4 w-4" />
|
||||
<span class="sr-only">Close sidebar</span>
|
||||
</Button>
|
||||
{#if innerWidth > 768 || (!page.url.hash.includes(ROUTES.SETTINGS) && !page.url.hash.includes(ROUTES.MCP_SERVERS) && !page.url.hash.includes(ROUTES.SEARCH))}
|
||||
<aside
|
||||
class={[
|
||||
// Layout & positioning
|
||||
'fixed md:sticky top-2 left-2 md:left-0 md:ml-2 md:mt-2 pt-2 z-10 w-[calc(100dvw-1rem)]',
|
||||
// Dimensions & overflow
|
||||
'md:h-[calc(100dvh-1.125rem)]',
|
||||
isExpandedMode &&
|
||||
(device.isStandalone
|
||||
? 'h-[calc(100dvh-2rem)]'
|
||||
: device.isIOSDevice
|
||||
? 'h-[calc(100dvh-0.5rem)]'
|
||||
: 'h-[calc(100dvh-1rem)]'),
|
||||
// Shape & depth
|
||||
'rounded-3xl md:rounded-2xl',
|
||||
// Flex layout
|
||||
'flex flex-col justify-between',
|
||||
// Transition
|
||||
'md:transition-[width,padding] duration-200 ease-out',
|
||||
// Expanded state: width, surface, depth
|
||||
isStripExpanded && 'md:w-72 md:bg-muted/60 md:backdrop-blur-xl border-border shadow-md',
|
||||
// Collapsed state
|
||||
!isStripExpanded && 'md:w-12',
|
||||
// Expanded mode flag (for mobile ::before overlay)
|
||||
isExpandedMode && 'is-expanded'
|
||||
]}
|
||||
>
|
||||
<div class="px-2 flex items-center justify-between">
|
||||
<div
|
||||
role="button"
|
||||
tabindex="0"
|
||||
class="relative"
|
||||
onmouseenter={() => (logoHovered = true)}
|
||||
onmouseleave={() => (logoHovered = false)}
|
||||
>
|
||||
<ActionIcon
|
||||
icon={!isExpandedMode && logoHovered && innerWidth > 768 ? PanelLeftOpen : Logo}
|
||||
size="lg"
|
||||
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
|
||||
class="{isExpandedMode
|
||||
? 'bg-muted! md:bg-foreground/5!'
|
||||
: 'bg-transparent!'} md:h-9 md:w-9 h-10 w-10 rounded-full md:hover:bg-foreground/10! pointer-events-auto"
|
||||
href={isExpandedMode ? ROUTES.START : undefined}
|
||||
onclick={isExpandedMode ? undefined : toggleExpandedMode}
|
||||
tooltip={isExpandedMode ? undefined : 'Open Sidebar'}
|
||||
tooltipSide={TooltipSide.RIGHT}
|
||||
ariaLabel={isExpandedMode ? 'Go to start' : 'Expand navigation'}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<SidebarNavigationActions
|
||||
bind:this={chatSidebarActions}
|
||||
{handleMobileSidebarItemClick}
|
||||
bind:isSearchModeActive
|
||||
bind:searchQuery
|
||||
onSearchDeactivated={handleSearchDeactivated}
|
||||
/>
|
||||
</Sidebar.Header>
|
||||
|
||||
{#if !isSearchModeActive && pinnedConversations.length > 0}
|
||||
<Sidebar.Group class="p-0 px-4">
|
||||
<Sidebar.GroupLabel>
|
||||
<div class="flex items-center gap-1">
|
||||
<Pin class="h-3.5 w-3.5" />
|
||||
<span>Pinned</span>
|
||||
</div>
|
||||
</Sidebar.GroupLabel>
|
||||
<Sidebar.GroupContent>
|
||||
<Sidebar.Menu>
|
||||
{#each pinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<Sidebar.MenuItem class="mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
onSelect={selectConversation}
|
||||
onEdit={handleEditConversation}
|
||||
onDelete={handleDeleteConversation}
|
||||
onStop={handleStopGeneration}
|
||||
/>
|
||||
</Sidebar.MenuItem>
|
||||
{/each}
|
||||
</Sidebar.Menu>
|
||||
</Sidebar.GroupContent>
|
||||
</Sidebar.Group>
|
||||
{/if}
|
||||
|
||||
<Sidebar.Group class="mt-2 h-[calc(100vh-21rem)] space-y-2 p-0 px-3">
|
||||
{#if (filteredConversations.length > 0 && isSearchModeActive) || !isSearchModeActive}
|
||||
<Sidebar.GroupLabel>
|
||||
{isSearchModeActive ? 'Search results' : 'Recent conversations'}
|
||||
</Sidebar.GroupLabel>
|
||||
{#if isExpandedMode || isOnMobile}
|
||||
<div
|
||||
class="flex items-center transition-all duration-150 ease-out {isMobile.current &&
|
||||
!isExpandedMode
|
||||
? 'opacity-0 h-0!'
|
||||
: ''}"
|
||||
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
|
||||
out:fade={{ duration: 100 }}
|
||||
>
|
||||
<ActionIcon
|
||||
icon={isMobile.current ? X : PanelLeftClose}
|
||||
size="lg"
|
||||
iconSize="h-4.5 w-4.5 md:h-4 md:w-4"
|
||||
class="backdrop-blur-none md:h-9 md:w-9 h-10 w-10 rounded-full mr-1 hover:bg-accent!"
|
||||
onclick={toggleExpandedMode}
|
||||
tooltip="Close Sidebar"
|
||||
tooltipSide={TooltipSide.LEFT}
|
||||
ariaLabel="Collapse navigation"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<Sidebar.GroupContent>
|
||||
<Sidebar.Menu>
|
||||
{#each isSearchModeActive ? conversationTree : unpinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<Sidebar.MenuItem class="mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
onSelect={selectConversation}
|
||||
onEdit={handleEditConversation}
|
||||
onDelete={handleDeleteConversation}
|
||||
onStop={handleStopGeneration}
|
||||
/>
|
||||
</Sidebar.MenuItem>
|
||||
{/each}
|
||||
|
||||
{#if (isSearchModeActive ? conversationTree : unpinnedConversations).length === 0}
|
||||
<div class="px-2 py-4 text-center">
|
||||
<p class="mb-4 p-4 text-sm text-muted-foreground">
|
||||
{searchQuery.length > 0
|
||||
? 'No results found'
|
||||
: isSearchModeActive
|
||||
? 'Start typing to see results'
|
||||
: 'No conversations yet'}
|
||||
</p>
|
||||
</div>
|
||||
{/if}
|
||||
</Sidebar.Menu>
|
||||
</Sidebar.GroupContent>
|
||||
</Sidebar.Group>
|
||||
</ScrollArea>
|
||||
</div>
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showDeleteDialog}
|
||||
title="Delete Conversation"
|
||||
description={selectedConversation
|
||||
? `Are you sure you want to delete "${selectedConversationNamePreview}"? This action cannot be undone and will permanently remove all messages in this conversation.`
|
||||
: ''}
|
||||
confirmText="Delete"
|
||||
cancelText="Cancel"
|
||||
variant="destructive"
|
||||
icon={Trash2}
|
||||
onConfirm={handleConfirmDelete}
|
||||
onCancel={() => {
|
||||
showDeleteDialog = false;
|
||||
selectedConversation = null;
|
||||
}}
|
||||
>
|
||||
{#if selectedConversationHasDescendants}
|
||||
<div class="flex items-center gap-2 py-2">
|
||||
<Checkbox id="delete-with-forks" bind:checked={deleteWithForks} />
|
||||
|
||||
<Label for="delete-with-forks" class="text-sm">Also delete all forked conversations</Label>
|
||||
</div>
|
||||
{/if}
|
||||
</DialogConfirmation>
|
||||
|
||||
<DialogConfirmation
|
||||
bind:open={showEditDialog}
|
||||
title="Edit Conversation Name"
|
||||
description=""
|
||||
confirmText="Save"
|
||||
cancelText="Cancel"
|
||||
icon={Pencil}
|
||||
onConfirm={handleConfirmEdit}
|
||||
onCancel={() => {
|
||||
showEditDialog = false;
|
||||
selectedConversation = null;
|
||||
}}
|
||||
onKeydown={(event) => {
|
||||
if (event.key === 'Enter') {
|
||||
event.preventDefault();
|
||||
event.stopImmediatePropagation();
|
||||
handleConfirmEdit();
|
||||
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-1 overflow-y-auto">
|
||||
<div
|
||||
class="flex min-h-0 flex-1 flex-col gap-4 md:gap-1 {isMobile.current
|
||||
? 'transition-[opacity,height] duration-200 ease-out'
|
||||
: ''} {isMobile.current && !isExpandedMode ? 'opacity-0 !h-0' : ''}"
|
||||
in:fade={{ duration: 200 }}
|
||||
out:fade={{ duration: 200 }}
|
||||
>
|
||||
<SidebarNavigationActions
|
||||
isExpandedMode={innerWidth > 768 ? isExpandedMode : true}
|
||||
class="px-2"
|
||||
bind:isSearchModeActive
|
||||
bind:searchQuery
|
||||
onSearchDeactivated={() => {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
}}
|
||||
onSearchClick={() => {
|
||||
isExpandedMode = true;
|
||||
isSearchModeActive = true;
|
||||
}}
|
||||
onNewChat={() => {
|
||||
if (isMobile.current) {
|
||||
scheduleMobileCollapse();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
{#if isExpandedMode || isOnMobile}
|
||||
<SidebarNavigationConversationList
|
||||
class="px-2"
|
||||
{filteredConversations}
|
||||
{currentChatId}
|
||||
{isSearchModeActive}
|
||||
{searchQuery}
|
||||
onSelect={selectConversation}
|
||||
onEdit={handleEditConversation}
|
||||
onDelete={handleDeleteConversation}
|
||||
onStop={handleStopGeneration}
|
||||
/>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</aside>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
aside {
|
||||
@media (max-width: 768px) {
|
||||
--size: 1.125rem;
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Input
|
||||
class="text-foreground"
|
||||
placeholder="Enter a new name"
|
||||
type="text"
|
||||
bind:value={editedName}
|
||||
/>
|
||||
</DialogConfirmation>
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
aside {
|
||||
&:not(.is-expanded) {
|
||||
pointer-events: none;
|
||||
}
|
||||
}
|
||||
|
||||
aside.is-expanded::before {
|
||||
content: '';
|
||||
position: fixed;
|
||||
top: -0.5rem;
|
||||
bottom: -0.25rem;
|
||||
left: -0.5rem;
|
||||
right: -0.5rem;
|
||||
z-index: -1;
|
||||
background: var(--background);
|
||||
backdrop-filter: blur(1rem);
|
||||
pointer-events: none;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
||||
+157
-57
@@ -1,39 +1,86 @@
|
||||
<script lang="ts">
|
||||
import { KeyboardShortcutInfo } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import type { Component } from 'svelte';
|
||||
import { SearchInput } from '$lib/components/app';
|
||||
import { goto } from '$app/navigation';
|
||||
import { page } from '$app/state';
|
||||
import { SIDEBAR_ACTIONS_ITEMS } from '$lib/constants/ui';
|
||||
import { Search } from '@lucide/svelte';
|
||||
import { ActionIcon, KeyboardShortcutInfo, SearchInput } from '$lib/components/app';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import {
|
||||
ICON_STRIP_TRANSITION_DURATION,
|
||||
ICON_STRIP_TRANSITION_DELAY_MULTIPLIER,
|
||||
ROUTES,
|
||||
SIDEBAR_ACTIONS_ITEMS
|
||||
} from '$lib/constants';
|
||||
import { isMobile } from '$lib/stores/viewport.svelte';
|
||||
import { TooltipSide } from '$lib/enums';
|
||||
import { fade } from 'svelte/transition';
|
||||
import { circIn } from 'svelte/easing';
|
||||
import { onMount } from 'svelte';
|
||||
import type { Component } from 'svelte';
|
||||
|
||||
interface Props {
|
||||
handleMobileSidebarItemClick: () => void;
|
||||
class: string;
|
||||
isExpandedMode: boolean;
|
||||
isSearchModeActive: boolean;
|
||||
searchQuery: string;
|
||||
isCancelAlwaysVisible?: boolean;
|
||||
onSearchDeactivated?: () => void;
|
||||
onSearchClick?: () => void;
|
||||
onNewChat?: () => void;
|
||||
}
|
||||
|
||||
let {
|
||||
handleMobileSidebarItemClick,
|
||||
isSearchModeActive = $bindable(),
|
||||
searchQuery = $bindable(),
|
||||
isCancelAlwaysVisible = false,
|
||||
onSearchDeactivated
|
||||
class: className,
|
||||
isExpandedMode = false,
|
||||
isSearchModeActive = $bindable(false),
|
||||
searchQuery = $bindable(''),
|
||||
onSearchDeactivated,
|
||||
onSearchClick,
|
||||
onNewChat
|
||||
}: Props = $props();
|
||||
|
||||
let initialized = $state(false);
|
||||
let showIcons = $state(false);
|
||||
let searchInputRef = $state<HTMLInputElement | null>(null);
|
||||
|
||||
const isOnMobile = $derived(isMobile.current);
|
||||
|
||||
$effect(() => {
|
||||
if (isSearchModeActive && searchInputRef) {
|
||||
searchInputRef.focus();
|
||||
}
|
||||
});
|
||||
|
||||
onMount(() => {
|
||||
showIcons = true;
|
||||
|
||||
setTimeout(() => {
|
||||
initialized = true;
|
||||
}, ICON_STRIP_TRANSITION_DELAY_MULTIPLIER * SIDEBAR_ACTIONS_ITEMS.length);
|
||||
});
|
||||
|
||||
function handleSearchModeDeactivate() {
|
||||
isSearchModeActive = false;
|
||||
searchQuery = '';
|
||||
onSearchDeactivated?.();
|
||||
}
|
||||
|
||||
export function activateSearch() {
|
||||
isSearchModeActive = true;
|
||||
// Focus after Svelte renders the input
|
||||
queueMicrotask(() => searchInputRef?.focus());
|
||||
function isItemActive(item: {
|
||||
activeRouteId?: string;
|
||||
activeRoutePrefix?: string;
|
||||
activeUrlIncludes?: string;
|
||||
}): boolean {
|
||||
if (item.activeRouteId) {
|
||||
return page.route.id === item.activeRouteId;
|
||||
}
|
||||
|
||||
if (item.activeRoutePrefix) {
|
||||
return !!page.route.id?.startsWith(item.activeRoutePrefix);
|
||||
}
|
||||
|
||||
if (item.activeUrlIncludes) {
|
||||
return page.url?.hash?.includes(item.activeUrlIncludes) ?? false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -41,56 +88,109 @@
|
||||
<IconComponent class="h-4 w-4" />
|
||||
{/snippet}
|
||||
|
||||
<div class="my-1 space-y-1">
|
||||
{#if isSearchModeActive}
|
||||
{#if isSearchModeActive}
|
||||
<div class="px-4 my-2">
|
||||
<SearchInput
|
||||
bind:value={searchQuery}
|
||||
bind:ref={searchInputRef}
|
||||
onClose={handleSearchModeDeactivate}
|
||||
onKeyDown={(e) => e.key === 'Escape' && handleSearchModeDeactivate()}
|
||||
placeholder="Search conversations..."
|
||||
{isCancelAlwaysVisible}
|
||||
/>
|
||||
{:else}
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item (item.route)}
|
||||
{#if !item.route}
|
||||
<Button
|
||||
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100"
|
||||
onclick={activateSearch}
|
||||
variant="ghost"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
{@render itemIcon(item.icon)}
|
||||
</div>
|
||||
{:else if isExpandedMode || isOnMobile}
|
||||
<div
|
||||
class="{className} flex flex-col gap-5 md:gap-1 mt-2 md:mt-0 {!isExpandedMode && isOnMobile
|
||||
? 'hidden pointer-events-none'
|
||||
: ''}"
|
||||
>
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
|
||||
{@const isActive = isItemActive(item)}
|
||||
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
|
||||
{@const itemHref = isSearchOnMobile ? ROUTES.SEARCH : item.route}
|
||||
{@const itemOnClick = item.route
|
||||
? () => {
|
||||
onNewChat?.();
|
||||
goto(item.route!);
|
||||
}
|
||||
: isSearchOnMobile
|
||||
? undefined
|
||||
: onSearchClick}
|
||||
{@const itemTransition = {
|
||||
duration: ICON_STRIP_TRANSITION_DURATION,
|
||||
delay: !initialized
|
||||
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
|
||||
: 0,
|
||||
easing: circIn
|
||||
}}
|
||||
|
||||
{item.tooltip}
|
||||
</div>
|
||||
{#if showIcons}
|
||||
<div transition:fade={itemTransition}>
|
||||
<Button
|
||||
class="w-full min-w-9 justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {isActive
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
href={itemHref}
|
||||
onclick={itemOnClick}
|
||||
variant="ghost"
|
||||
size="default"
|
||||
>
|
||||
<span class="flex min-w-0 items-center px-0.5 gap-2">
|
||||
{@render itemIcon(item.icon)}
|
||||
|
||||
{#if item.keys}
|
||||
<KeyboardShortcutInfo keys={item.keys} />
|
||||
{/if}
|
||||
</Button>
|
||||
{:else}
|
||||
<Button
|
||||
class="w-full justify-between px-2 backdrop-blur-none! hover:[&>kbd]:opacity-100 {(item.activeRouteId &&
|
||||
page.route.id === item.activeRouteId) ||
|
||||
(item.activeRoutePrefix && page.route.id?.startsWith(item.activeRoutePrefix))
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
href={item.route}
|
||||
onclick={handleMobileSidebarItemClick}
|
||||
variant="ghost"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
{@render itemIcon(item.icon)}
|
||||
{#if showIcons}
|
||||
<span
|
||||
in:fade={{ duration: 150, easing: circIn, delay: 50 }}
|
||||
out:fade={{ duration: 100 }}
|
||||
class="min-w-0 truncate">{item.tooltip}</span
|
||||
>
|
||||
{/if}
|
||||
</span>
|
||||
|
||||
{item.tooltip}
|
||||
</div>
|
||||
|
||||
{#if item.keys}
|
||||
<KeyboardShortcutInfo keys={item.keys} />
|
||||
{/if}
|
||||
</Button>
|
||||
{#if item.keys}
|
||||
<KeyboardShortcutInfo keys={item.keys} />
|
||||
{/if}
|
||||
</Button>
|
||||
</div>
|
||||
{/if}
|
||||
{/each}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="{className} flex-col gap-1 hidden md:flex">
|
||||
{#each SIDEBAR_ACTIONS_ITEMS as item, i (item.tooltip)}
|
||||
{@const isActive = isItemActive(item)}
|
||||
{@const isSearchOnMobile = item.icon === Search && isMobile.current}
|
||||
{@const itemOnClick = item.route
|
||||
? () => {
|
||||
onNewChat?.();
|
||||
goto(item.route!);
|
||||
}
|
||||
: isSearchOnMobile
|
||||
? undefined
|
||||
: onSearchClick}
|
||||
{@const itemTransition = {
|
||||
duration: ICON_STRIP_TRANSITION_DURATION,
|
||||
delay: !initialized
|
||||
? ICON_STRIP_TRANSITION_DELAY_MULTIPLIER + i * ICON_STRIP_TRANSITION_DELAY_MULTIPLIER
|
||||
: 0,
|
||||
easing: circIn
|
||||
}}
|
||||
|
||||
{#if showIcons}
|
||||
<div transition:fade={itemTransition}>
|
||||
<ActionIcon
|
||||
icon={item.icon}
|
||||
tooltip={item.tooltip}
|
||||
tooltipSide={TooltipSide.RIGHT}
|
||||
size="lg"
|
||||
iconSize="h-4 w-4"
|
||||
class="h-9 w-9 rounded-full hover:bg-accent! {isActive
|
||||
? 'bg-accent text-accent-foreground'
|
||||
: ''}"
|
||||
onclick={itemOnClick}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
+135
@@ -0,0 +1,135 @@
|
||||
<script lang="ts">
|
||||
import { Pin } from '@lucide/svelte';
|
||||
import { buildConversationTree } from '$lib/stores/conversations.svelte';
|
||||
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
|
||||
import SidebarNavigationSearchResults from './SidebarNavigationSearchResults.svelte';
|
||||
|
||||
interface Props {
|
||||
class: string;
|
||||
filteredConversations: DatabaseConversation[];
|
||||
currentChatId: string | undefined;
|
||||
isSearchModeActive: boolean;
|
||||
searchQuery: string;
|
||||
onSelect: (id: string) => void;
|
||||
onEdit: (id: string) => void;
|
||||
onDelete: (id: string) => void;
|
||||
onStop: (id: string) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className,
|
||||
filteredConversations,
|
||||
currentChatId,
|
||||
isSearchModeActive,
|
||||
searchQuery,
|
||||
onSelect,
|
||||
onEdit,
|
||||
onDelete,
|
||||
onStop
|
||||
}: Props = $props();
|
||||
|
||||
let conversationTree = $derived(buildConversationTree(filteredConversations));
|
||||
|
||||
let pinnedConversations = $derived(
|
||||
conversationTree.filter(({ conversation }) => conversation.pinned)
|
||||
);
|
||||
|
||||
let unpinnedConversations = $derived(
|
||||
conversationTree.filter(({ conversation }) => !conversation.pinned)
|
||||
);
|
||||
|
||||
const recentEmptyMessage = $derived(
|
||||
searchQuery.length > 0 ? 'No results found' : 'No conversations yet'
|
||||
);
|
||||
</script>
|
||||
|
||||
{#if isSearchModeActive}
|
||||
<SidebarNavigationSearchResults
|
||||
class={className}
|
||||
{searchQuery}
|
||||
{filteredConversations}
|
||||
{currentChatId}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
{:else}
|
||||
{#if pinnedConversations.length > 0}
|
||||
<div class="py-2 flex whitespace-nowrap {className}">
|
||||
<div
|
||||
class="text-muted-foreground inline-flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium gap-1"
|
||||
>
|
||||
<Pin class="h-3.5 w-3.5" />
|
||||
|
||||
<span>Pinned</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1 {className}">
|
||||
{#each pinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<li class="group/item relative mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
</li>
|
||||
{/each}
|
||||
</ul>
|
||||
{/if}
|
||||
|
||||
<div class="mt-2 flex min-h-0 flex-1 flex-col gap-4 md:gap-2 whitespace-nowrap {className}">
|
||||
{#if filteredConversations.length > 0}
|
||||
<div
|
||||
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
|
||||
>
|
||||
Recent conversations
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="min-h-0 flex-1 md:overflow-y-auto">
|
||||
<ul class="flex w-full min-w-0 flex-col gap-4 md:gap-1">
|
||||
{#each unpinnedConversations as { conversation, depth } (conversation.id)}
|
||||
<li class="group/item relative mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
</li>
|
||||
{/each}
|
||||
|
||||
{#if unpinnedConversations.length === 0}
|
||||
<li class="px-2 py-4 text-center">
|
||||
<p class="mb-4 p-4 text-sm text-muted-foreground">
|
||||
{recentEmptyMessage}
|
||||
</p>
|
||||
</li>
|
||||
{/if}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
+3
-1
@@ -16,4 +16,6 @@
|
||||
}: Props = $props();
|
||||
</script>
|
||||
|
||||
<SearchInput bind:value {placeholder} {onInput} class="mb-4 {className}" />
|
||||
<div class="mb-4 px-2 {className}">
|
||||
<SearchInput bind:value {placeholder} {onInput} />
|
||||
</div>
|
||||
|
||||
+76
@@ -0,0 +1,76 @@
|
||||
<script lang="ts">
|
||||
import { buildConversationTree } from '$lib/stores/conversations.svelte';
|
||||
import SidebarNavigationConversationItem from './SidebarNavigationConversationItem.svelte';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
searchQuery: string;
|
||||
filteredConversations: DatabaseConversation[];
|
||||
currentChatId: string | undefined;
|
||||
onSelect: (id: string) => void;
|
||||
onEdit: (id: string) => void;
|
||||
onDelete: (id: string) => void;
|
||||
onStop: (id: string) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
class: className = '',
|
||||
searchQuery,
|
||||
filteredConversations,
|
||||
currentChatId,
|
||||
onSelect,
|
||||
onEdit,
|
||||
onDelete,
|
||||
onStop
|
||||
}: Props = $props();
|
||||
|
||||
let tree = $derived(buildConversationTree(filteredConversations));
|
||||
|
||||
const hasQuery = $derived(searchQuery.trim().length > 0);
|
||||
const showHeader = $derived(hasQuery && filteredConversations.length > 0);
|
||||
|
||||
const emptyMessage = $derived(hasQuery ? 'No results found' : 'Start typing to see results');
|
||||
</script>
|
||||
|
||||
<div class="flex min-h-0 flex-1 flex-col gap-2 whitespace-nowrap {className}">
|
||||
{#if showHeader}
|
||||
<div
|
||||
class="text-muted-foreground flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium"
|
||||
>
|
||||
Search results
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<div class="min-h-0 flex-1 overflow-y-auto">
|
||||
<ul class="flex w-full min-w-0 flex-col gap-1">
|
||||
{#each tree as { conversation, depth } (conversation.id)}
|
||||
<li class="group/item relative mb-1 p-0">
|
||||
<SidebarNavigationConversationItem
|
||||
conversation={{
|
||||
id: conversation.id,
|
||||
name: conversation.name,
|
||||
lastModified: conversation.lastModified,
|
||||
currNode: conversation.currNode,
|
||||
forkedFromConversationId: conversation.forkedFromConversationId,
|
||||
pinned: conversation.pinned
|
||||
}}
|
||||
{depth}
|
||||
isActive={currentChatId === conversation.id}
|
||||
{onSelect}
|
||||
{onEdit}
|
||||
{onDelete}
|
||||
{onStop}
|
||||
/>
|
||||
</li>
|
||||
{/each}
|
||||
|
||||
{#if tree.length === 0}
|
||||
<li class="px-2 py-4 text-center">
|
||||
<p class="mb-4 p-4 text-sm text-muted-foreground">
|
||||
{emptyMessage}
|
||||
</p>
|
||||
</li>
|
||||
{/if}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
@@ -63,15 +63,6 @@ export { default as DropdownMenuSearchable } from './DropdownMenuSearchable.svel
|
||||
* ```
|
||||
*/
|
||||
export { default as DropdownMenuActions } from './DropdownMenuActions.svelte';
|
||||
|
||||
/**
|
||||
* **DesktopIconStrip** - Fixed icon strip for desktop sidebar
|
||||
*
|
||||
* Vertical icon strip shown on desktop when the sidebar is collapsed.
|
||||
* Contains navigation shortcuts for new chat, search, MCP, import/export, and settings.
|
||||
*/
|
||||
export { default as DesktopIconStrip } from './DesktopIconStrip.svelte';
|
||||
|
||||
/**
|
||||
* **SidebarNavigation** - Sidebar with actions menu and conversation list
|
||||
*
|
||||
@@ -115,13 +106,6 @@ export { default as DesktopIconStrip } from './DesktopIconStrip.svelte';
|
||||
*/
|
||||
export { default as SidebarNavigation } from './SidebarNavigation/SidebarNavigation.svelte';
|
||||
|
||||
/**
|
||||
* Action buttons for sidebar header. Contains new chat button, settings button,
|
||||
* and delete all conversations button. Manages dialog states for settings and
|
||||
* delete confirmation.
|
||||
*/
|
||||
export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte';
|
||||
|
||||
/**
|
||||
* Single conversation item in sidebar. Displays conversation title (truncated),
|
||||
* last message preview, and timestamp. Shows context menu on right-click with
|
||||
@@ -130,6 +114,58 @@ export { default as SidebarNavigationActions } from './SidebarNavigation/Sidebar
|
||||
*/
|
||||
export { default as SidebarNavigationConversationItem } from './SidebarNavigation/SidebarNavigationConversationItem.svelte';
|
||||
|
||||
/**
|
||||
* **SidebarNavigationConversationList** - Grouped conversation list
|
||||
*
|
||||
* Pure-presentational list of conversations. Splits items into a Pinned
|
||||
* section (when not in search mode) and a Recent Conversations / Search
|
||||
* Results section with the unpinned items. Item selection, edit, delete,
|
||||
* and stop-generation are delegated to the caller via callbacks.
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <SidebarNavigationConversationList
|
||||
* {filteredConversations}
|
||||
* {currentChatId}
|
||||
* {isSearchModeActive}
|
||||
* {searchQuery}
|
||||
* onSelect={...}
|
||||
* onEdit={...}
|
||||
* onDelete={...}
|
||||
* onStop={...}
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
export { default as SidebarNavigationConversationList } from './SidebarNavigation/SidebarNavigationConversationList.svelte';
|
||||
export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte';
|
||||
|
||||
/**
|
||||
* **SidebarNavigationSearchResults** - Filtered conversation list for search.
|
||||
*
|
||||
* Pure-presentational rendering of the search-mode subtree: "Search results"
|
||||
* header, the matching items rendered through {@link SidebarNavigationConversationItem},
|
||||
* and contextual empty-state messages. Used both inline inside
|
||||
* {@link SidebarNavigationConversationList} (when search mode is active in the
|
||||
* sidebar) and as the body of the mobile `/search` route.
|
||||
*
|
||||
* The caller is expected to provide an already-filtered list via
|
||||
* `filteredConversations` and a `searchQuery` for the empty-state messages.
|
||||
*
|
||||
* @example
|
||||
* ```svelte
|
||||
* <SidebarNavigationSearchResults
|
||||
* {searchQuery}
|
||||
* {filteredConversations}
|
||||
* {currentChatId}
|
||||
* onSelect={...}
|
||||
* onEdit={...}
|
||||
* onDelete={...}
|
||||
* onStop={...}
|
||||
* />
|
||||
* ```
|
||||
*/
|
||||
export { default as SidebarNavigationSearchResults } from './SidebarNavigation/SidebarNavigationSearchResults.svelte';
|
||||
|
||||
/**
|
||||
* Search input for filtering conversations in sidebar. Filters conversation
|
||||
* list by title as user types. Shows clear button when query is not empty.
|
||||
|
||||
@@ -126,10 +126,7 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div
|
||||
class="mx-auto flex h-full max-h-[100dvh] w-full flex-col overflow-y-auto md:pl-8"
|
||||
in:fade={{ duration: 150 }}
|
||||
>
|
||||
<div class="mx-auto flex h-full w-full flex-col md:pl-8" in:fade={{ duration: 150 }}>
|
||||
<div class="flex flex-1 flex-col gap-4 md:flex-row">
|
||||
<SettingsChatDesktopSidebar
|
||||
sections={SETTINGS_CHAT_SECTIONS}
|
||||
|
||||
@@ -12,11 +12,13 @@
|
||||
let { sections, isActive, getHref, onSectionChange }: Props = $props();
|
||||
</script>
|
||||
|
||||
<div class="sticky top-0 hidden w-64 flex-col self-start bg-background pt-10 pb-4 md:flex">
|
||||
<div class="flex items-center gap-2 pb-10">
|
||||
<Settings class="h-6 w-6" />
|
||||
<h1 class="text-2xl font-semibold">Settings</h1>
|
||||
<div class="sticky top-2 hidden w-64 flex-col self-start bg-background py-4 md:flex gap-6">
|
||||
<div class="flex items-center gap-2 py-2">
|
||||
<Settings class="h-5 w-5 md:h-6 md:w-6" />
|
||||
|
||||
<h1 class="text-xl font-semibold md:text-2xl">Settings</h1>
|
||||
</div>
|
||||
|
||||
<nav class="space-y-1">
|
||||
{#each sections as section (section.title)}
|
||||
{#if getHref}
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
<script lang="ts">
|
||||
import { Plus } from '@lucide/svelte';
|
||||
import { X, Plus } from '@lucide/svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { mcpStore } from '$lib/stores/mcp.svelte';
|
||||
import { conversationsStore } from '$lib/stores/conversations.svelte';
|
||||
import { toolsStore } from '$lib/stores/tools.svelte';
|
||||
import { McpServerCard, McpServerCardSkeleton } from '$lib/components/app/mcp';
|
||||
import { ActionIcon, McpServerCard, McpServerCardSkeleton } from '$lib/components/app';
|
||||
import { DialogMcpServerAddNew } from '$lib/components/app/dialogs';
|
||||
import { HealthCheckStatus } from '$lib/enums';
|
||||
import { ROUTES } from '$lib/constants';
|
||||
import { fade } from 'svelte/transition';
|
||||
import { onMount } from 'svelte';
|
||||
import McpLogo from '../mcp/McpLogo.svelte';
|
||||
import { browser } from '$app/environment';
|
||||
import { page } from '$app/state';
|
||||
import { replaceState } from '$app/navigation';
|
||||
import { goto, replaceState } from '$app/navigation';
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -24,6 +26,24 @@
|
||||
let initialLoadComplete = $state(false);
|
||||
let isAddingServer = $state(false);
|
||||
|
||||
let previousRouteId = $state<string | null>(null);
|
||||
|
||||
$effect(() => {
|
||||
const currentId = page.route.id;
|
||||
return () => {
|
||||
previousRouteId = currentId;
|
||||
};
|
||||
});
|
||||
|
||||
function handleClose() {
|
||||
const prevIsMcpServers = previousRouteId === '/mcp-servers';
|
||||
if (browser && window.history.length > 1 && !prevIsMcpServers) {
|
||||
history.back();
|
||||
} else {
|
||||
goto(ROUTES.START);
|
||||
}
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
if (page.url.searchParams.has('add')) {
|
||||
isAddingServer = true;
|
||||
@@ -54,15 +74,26 @@
|
||||
});
|
||||
</script>
|
||||
|
||||
<div in:fade={{ duration: 150 }} class="h-full max-h-[100dvh] overflow-y-auto">
|
||||
<div class="flex items-center gap-2 p-4 md:absolute md:top-8 md:left-8 md:px-0 md:py-2">
|
||||
<McpLogo class="h-5 w-5 md:h-6 md:w-6" />
|
||||
|
||||
<h1 class="text-xl font-semibold md:text-2xl">MCP Servers</h1>
|
||||
<div in:fade={{ duration: 150 }}>
|
||||
<div class="fixed top-4.5 right-4 z-50 md:hidden">
|
||||
<ActionIcon icon={X} tooltip="Close" onclick={handleClose} />
|
||||
</div>
|
||||
|
||||
<div class="sticky top-0 z-10 mt-4 flex items-start gap-4 p-4 md:justify-end md:px-8">
|
||||
<Button variant="outline" size="sm" class="shrink-0" onclick={() => (isAddingServer = true)}>
|
||||
<div
|
||||
class="sticky top-0 z-10 mt-4 mb-2 flex items-start gap-4 md:p-4 p-0 px-4 md:justify-between md:px-8"
|
||||
>
|
||||
<div class="flex items-center gap-2">
|
||||
<McpLogo class="h-5 w-5 md:h-6 md:w-6" />
|
||||
|
||||
<h1 class="text-lg font-semibold md:text-2xl">MCP Servers</h1>
|
||||
</div>
|
||||
|
||||
<Button
|
||||
variant="outline"
|
||||
size="lg"
|
||||
class="shrink-0 fixed md:static bottom-6 right-6"
|
||||
onclick={() => (isAddingServer = true)}
|
||||
>
|
||||
<Plus class="h-4 w-4" />
|
||||
|
||||
Add New Server
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
size: {
|
||||
default: 'h-9 px-4 py-2 has-[>svg]:px-3',
|
||||
sm: 'h-8 gap-1.5 rounded-md px-3 has-[>svg]:px-2.5',
|
||||
lg: 'h-10 rounded-md px-6 has-[>svg]:px-4',
|
||||
lg: 'h-10 rounded-lg px-6 has-[>svg]:px-4',
|
||||
'icon-lg': 'size-10',
|
||||
icon: 'size-9',
|
||||
'icon-sm': 'size-5 rounded-sm'
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
export const SIDEBAR_COOKIE_NAME = 'sidebar:state';
|
||||
export const SIDEBAR_COOKIE_MAX_AGE = 60 * 60 * 24 * 7;
|
||||
export const SIDEBAR_MIN_WIDTH = '18rem';
|
||||
export const SIDEBAR_MAX_WIDTH = '32rem';
|
||||
export const SIDEBAR_WIDTH_MOBILE = '18rem';
|
||||
export const SIDEBAR_WIDTH_ICON = '3rem';
|
||||
export const SIDEBAR_KEYBOARD_SHORTCUT = 'b';
|
||||
@@ -1,79 +0,0 @@
|
||||
import { isMobile } from '$lib/stores/viewport.svelte.js';
|
||||
import { getContext, setContext } from 'svelte';
|
||||
import { SIDEBAR_KEYBOARD_SHORTCUT, SIDEBAR_MIN_WIDTH } from './constants.js';
|
||||
|
||||
type Getter<T> = () => T;
|
||||
|
||||
export type SidebarStateProps = {
|
||||
/**
|
||||
* A getter function that returns the current open state of the sidebar.
|
||||
* We use a getter function here to support `bind:open` on the `Sidebar.Provider`
|
||||
* component.
|
||||
*/
|
||||
open: Getter<boolean>;
|
||||
|
||||
/**
|
||||
* A function that sets the open state of the sidebar. To support `bind:open`, we need
|
||||
* a source of truth for changing the open state to ensure it will be synced throughout
|
||||
* the sub-components and any `bind:` references.
|
||||
*/
|
||||
setOpen: (open: boolean) => void;
|
||||
};
|
||||
|
||||
class SidebarState {
|
||||
readonly props: SidebarStateProps;
|
||||
open = $derived.by(() => this.props.open());
|
||||
openMobile = $state(false);
|
||||
sidebarWidth = $state(SIDEBAR_MIN_WIDTH);
|
||||
isResizing = $state(false);
|
||||
setOpen: SidebarStateProps['setOpen'];
|
||||
state = $derived.by(() => (this.open ? 'expanded' : 'collapsed'));
|
||||
|
||||
constructor(props: SidebarStateProps) {
|
||||
this.setOpen = props.setOpen;
|
||||
this.props = props;
|
||||
}
|
||||
|
||||
// Convenience getter for checking if the sidebar is mobile
|
||||
// without this, we would need to use `sidebar.isMobile.current` everywhere
|
||||
get isMobile() {
|
||||
return isMobile.current;
|
||||
}
|
||||
|
||||
// Event handler to apply to the `<svelte:window>`
|
||||
handleShortcutKeydown = (e: KeyboardEvent) => {
|
||||
if (e.key === SIDEBAR_KEYBOARD_SHORTCUT && (e.metaKey || e.ctrlKey)) {
|
||||
e.preventDefault();
|
||||
this.toggle();
|
||||
}
|
||||
};
|
||||
|
||||
setOpenMobile = (value: boolean) => {
|
||||
this.openMobile = value;
|
||||
};
|
||||
|
||||
toggle = () => {
|
||||
this.setOpen(!this.open);
|
||||
};
|
||||
}
|
||||
|
||||
const SYMBOL_KEY = 'scn-sidebar';
|
||||
|
||||
/**
|
||||
* Instantiates a new `SidebarState` instance and sets it in the context.
|
||||
*
|
||||
* @param props The constructor props for the `SidebarState` class.
|
||||
* @returns The `SidebarState` instance.
|
||||
*/
|
||||
export function setSidebar(props: SidebarStateProps): SidebarState {
|
||||
return setContext(Symbol.for(SYMBOL_KEY), new SidebarState(props));
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the `SidebarState` instance from the context. This is a class instance,
|
||||
* so you cannot destructure it.
|
||||
* @returns The `SidebarState` instance.
|
||||
*/
|
||||
export function useSidebar(): SidebarState {
|
||||
return getContext(Symbol.for(SYMBOL_KEY));
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
import { useSidebar } from './context.svelte.js';
|
||||
import Content from './sidebar-content.svelte';
|
||||
import Footer from './sidebar-footer.svelte';
|
||||
import GroupAction from './sidebar-group-action.svelte';
|
||||
import GroupContent from './sidebar-group-content.svelte';
|
||||
import GroupLabel from './sidebar-group-label.svelte';
|
||||
import Group from './sidebar-group.svelte';
|
||||
import Header from './sidebar-header.svelte';
|
||||
import Input from './sidebar-input.svelte';
|
||||
import Inset from './sidebar-inset.svelte';
|
||||
import MenuAction from './sidebar-menu-action.svelte';
|
||||
import MenuBadge from './sidebar-menu-badge.svelte';
|
||||
import MenuButton from './sidebar-menu-button.svelte';
|
||||
import MenuItem from './sidebar-menu-item.svelte';
|
||||
import MenuSkeleton from './sidebar-menu-skeleton.svelte';
|
||||
import MenuSubButton from './sidebar-menu-sub-button.svelte';
|
||||
import MenuSubItem from './sidebar-menu-sub-item.svelte';
|
||||
import MenuSub from './sidebar-menu-sub.svelte';
|
||||
import Menu from './sidebar-menu.svelte';
|
||||
import Provider from './sidebar-provider.svelte';
|
||||
import Rail from './sidebar-rail.svelte';
|
||||
import Separator from './sidebar-separator.svelte';
|
||||
import Trigger from './sidebar-trigger.svelte';
|
||||
import Root from './sidebar.svelte';
|
||||
|
||||
export {
|
||||
Content,
|
||||
Footer,
|
||||
Group,
|
||||
GroupAction,
|
||||
GroupContent,
|
||||
GroupLabel,
|
||||
Header,
|
||||
Input,
|
||||
Inset,
|
||||
Menu,
|
||||
MenuAction,
|
||||
MenuBadge,
|
||||
MenuButton,
|
||||
MenuItem,
|
||||
MenuSkeleton,
|
||||
MenuSub,
|
||||
MenuSubButton,
|
||||
MenuSubItem,
|
||||
Provider,
|
||||
Rail,
|
||||
Root,
|
||||
Separator,
|
||||
//
|
||||
Root as Sidebar,
|
||||
Content as SidebarContent,
|
||||
Footer as SidebarFooter,
|
||||
Group as SidebarGroup,
|
||||
GroupAction as SidebarGroupAction,
|
||||
GroupContent as SidebarGroupContent,
|
||||
GroupLabel as SidebarGroupLabel,
|
||||
Header as SidebarHeader,
|
||||
Input as SidebarInput,
|
||||
Inset as SidebarInset,
|
||||
Menu as SidebarMenu,
|
||||
MenuAction as SidebarMenuAction,
|
||||
MenuBadge as SidebarMenuBadge,
|
||||
MenuButton as SidebarMenuButton,
|
||||
MenuItem as SidebarMenuItem,
|
||||
MenuSkeleton as SidebarMenuSkeleton,
|
||||
MenuSub as SidebarMenuSub,
|
||||
MenuSubButton as SidebarMenuSubButton,
|
||||
MenuSubItem as SidebarMenuSubItem,
|
||||
Provider as SidebarProvider,
|
||||
Rail as SidebarRail,
|
||||
Separator as SidebarSeparator,
|
||||
Trigger as SidebarTrigger,
|
||||
Trigger,
|
||||
useSidebar
|
||||
};
|
||||
@@ -1,24 +0,0 @@
|
||||
<script lang="ts">
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
class: className,
|
||||
children,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
|
||||
</script>
|
||||
|
||||
<div
|
||||
bind:this={ref}
|
||||
data-slot="sidebar-content"
|
||||
data-sidebar="content"
|
||||
class={cn(
|
||||
'flex min-h-0 flex-1 flex-col gap-2 overflow-auto group-data-[collapsible=icon]:overflow-hidden',
|
||||
className
|
||||
)}
|
||||
{...restProps}
|
||||
>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
@@ -1,21 +0,0 @@
|
||||
<script lang="ts">
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
class: className,
|
||||
children,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
|
||||
</script>
|
||||
|
||||
<div
|
||||
bind:this={ref}
|
||||
data-slot="sidebar-footer"
|
||||
data-sidebar="footer"
|
||||
class={cn('flex flex-col gap-2 p-3', className)}
|
||||
{...restProps}
|
||||
>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
@@ -1,36 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
|
||||
import type { Snippet } from 'svelte';
|
||||
import type { HTMLButtonAttributes } from 'svelte/elements';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
class: className,
|
||||
children,
|
||||
child,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLButtonAttributes> & {
|
||||
child?: Snippet<[{ props: Record<string, unknown> }]>;
|
||||
} = $props();
|
||||
|
||||
const mergedProps = $derived({
|
||||
class: cn(
|
||||
'text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground outline-hidden absolute right-3 top-3.5 flex aspect-square w-5 items-center justify-center rounded-md p-0 transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0',
|
||||
// Increases the hit area of the button on mobile.
|
||||
'after:absolute after:-inset-2 md:after:hidden',
|
||||
'group-data-[collapsible=icon]:hidden',
|
||||
className
|
||||
),
|
||||
'data-slot': 'sidebar-group-action',
|
||||
'data-sidebar': 'group-action',
|
||||
...restProps
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if child}
|
||||
{@render child({ props: mergedProps })}
|
||||
{:else}
|
||||
<button bind:this={ref} {...mergedProps}>
|
||||
{@render children?.()}
|
||||
</button>
|
||||
{/if}
|
||||
@@ -1,21 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
class: className,
|
||||
children,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLAttributes<HTMLDivElement>> = $props();
|
||||
</script>
|
||||
|
||||
<div
|
||||
bind:this={ref}
|
||||
data-slot="sidebar-group-content"
|
||||
data-sidebar="group-content"
|
||||
class={cn('w-full text-sm', className)}
|
||||
{...restProps}
|
||||
>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
@@ -1,34 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
|
||||
import type { Snippet } from 'svelte';
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
children,
|
||||
child,
|
||||
class: className,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLAttributes<HTMLElement>> & {
|
||||
child?: Snippet<[{ props: Record<string, unknown> }]>;
|
||||
} = $props();
|
||||
|
||||
const mergedProps = $derived({
|
||||
class: cn(
|
||||
'text-sidebar-foreground/70 ring-sidebar-ring outline-hidden flex h-8 shrink-0 items-center rounded-md px-2 text-xs font-medium transition-[margin,opacity] duration-200 ease-linear focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0',
|
||||
'group-data-[collapsible=icon]:-mt-8 group-data-[collapsible=icon]:opacity-0',
|
||||
className
|
||||
),
|
||||
'data-slot': 'sidebar-group-label',
|
||||
'data-sidebar': 'group-label',
|
||||
...restProps
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if child}
|
||||
{@render child({ props: mergedProps })}
|
||||
{:else}
|
||||
<div bind:this={ref} {...mergedProps}>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
{/if}
|
||||
@@ -1,21 +0,0 @@
|
||||
<script lang="ts">
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
class: className,
|
||||
children,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
|
||||
</script>
|
||||
|
||||
<div
|
||||
bind:this={ref}
|
||||
data-slot="sidebar-group"
|
||||
data-sidebar="group"
|
||||
class={cn('relative flex w-full min-w-0 flex-col p-2', className)}
|
||||
{...restProps}
|
||||
>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
@@ -1,21 +0,0 @@
|
||||
<script lang="ts">
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
class: className,
|
||||
children,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
|
||||
</script>
|
||||
|
||||
<div
|
||||
bind:this={ref}
|
||||
data-slot="sidebar-header"
|
||||
data-sidebar="header"
|
||||
class={cn('flex flex-col gap-2 p-2', className)}
|
||||
{...restProps}
|
||||
>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
@@ -1,21 +0,0 @@
|
||||
<script lang="ts">
|
||||
import type { ComponentProps } from 'svelte';
|
||||
import { Input } from '$lib/components/ui/input/index.js';
|
||||
import { cn } from '$lib/components/ui/utils.js';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
value = $bindable(''),
|
||||
class: className,
|
||||
...restProps
|
||||
}: ComponentProps<typeof Input> = $props();
|
||||
</script>
|
||||
|
||||
<Input
|
||||
bind:ref
|
||||
bind:value
|
||||
data-slot="sidebar-input"
|
||||
data-sidebar="input"
|
||||
class={cn('h-8 w-full bg-background shadow-none', className)}
|
||||
{...restProps}
|
||||
/>
|
||||
@@ -1,24 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
class: className,
|
||||
children,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
|
||||
</script>
|
||||
|
||||
<main
|
||||
bind:this={ref}
|
||||
data-slot="sidebar-inset"
|
||||
class={cn(
|
||||
'relative flex w-full flex-1 flex-col',
|
||||
'md:peer-data-[variant=inset]:m-2 md:peer-data-[variant=inset]:ml-0 md:peer-data-[variant=inset]:rounded-xl md:peer-data-[variant=inset]:shadow-sm md:peer-data-[variant=inset]:peer-data-[state=collapsed]:ml-2',
|
||||
className
|
||||
)}
|
||||
{...restProps}
|
||||
>
|
||||
{@render children?.()}
|
||||
</main>
|
||||
@@ -1,43 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
|
||||
import type { Snippet } from 'svelte';
|
||||
import type { HTMLButtonAttributes } from 'svelte/elements';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
class: className,
|
||||
showOnHover = false,
|
||||
children,
|
||||
child,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLButtonAttributes> & {
|
||||
child?: Snippet<[{ props: Record<string, unknown> }]>;
|
||||
showOnHover?: boolean;
|
||||
} = $props();
|
||||
|
||||
const mergedProps = $derived({
|
||||
class: cn(
|
||||
'text-sidebar-foreground ring-sidebar-ring hover:bg-sidebar-accent hover:text-sidebar-accent-foreground peer-hover/menu-button:text-sidebar-accent-foreground outline-hidden absolute right-1 top-1.5 flex aspect-square w-5 items-center justify-center rounded-md p-0 transition-transform focus-visible:ring-2 [&>svg]:size-4 [&>svg]:shrink-0',
|
||||
// Increases the hit area of the button on mobile.
|
||||
'after:absolute after:-inset-2 md:after:hidden',
|
||||
'peer-data-[size=sm]/menu-button:top-1',
|
||||
'peer-data-[size=default]/menu-button:top-1.5',
|
||||
'peer-data-[size=lg]/menu-button:top-2.5',
|
||||
'group-data-[collapsible=icon]:hidden',
|
||||
showOnHover &&
|
||||
'peer-data-[active=true]/menu-button:text-sidebar-accent-foreground group-focus-within/menu-item:opacity-100 group-hover/menu-item:opacity-100 data-[state=open]:opacity-100 md:opacity-0',
|
||||
className
|
||||
),
|
||||
'data-slot': 'sidebar-menu-action',
|
||||
'data-sidebar': 'menu-action',
|
||||
...restProps
|
||||
});
|
||||
</script>
|
||||
|
||||
{#if child}
|
||||
{@render child({ props: mergedProps })}
|
||||
{:else}
|
||||
<button bind:this={ref} {...mergedProps}>
|
||||
{@render children?.()}
|
||||
</button>
|
||||
{/if}
|
||||
@@ -1,29 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { cn, type WithElementRef } from '$lib/components/ui/utils.js';
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
|
||||
let {
|
||||
ref = $bindable(null),
|
||||
class: className,
|
||||
children,
|
||||
...restProps
|
||||
}: WithElementRef<HTMLAttributes<HTMLElement>> = $props();
|
||||
</script>
|
||||
|
||||
<div
|
||||
bind:this={ref}
|
||||
data-slot="sidebar-menu-badge"
|
||||
data-sidebar="menu-badge"
|
||||
class={cn(
|
||||
'pointer-events-none absolute right-1 flex h-5 min-w-5 items-center justify-center rounded-md px-1 text-xs font-medium text-sidebar-foreground tabular-nums select-none',
|
||||
'peer-hover/menu-button:text-sidebar-accent-foreground peer-data-[active=true]/menu-button:text-sidebar-accent-foreground',
|
||||
'peer-data-[size=sm]/menu-button:top-1',
|
||||
'peer-data-[size=default]/menu-button:top-1.5',
|
||||
'peer-data-[size=lg]/menu-button:top-2.5',
|
||||
'group-data-[collapsible=icon]:hidden',
|
||||
className
|
||||
)}
|
||||
{...restProps}
|
||||
>
|
||||
{@render children?.()}
|
||||
</div>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user