vulkan : refactor buffer handling in vk_op_f32 (#16840)
* vulkan : refactor/simplify buffer handling in vk_op_* functions
* Combine UMA handling into ggml_vk_tensor_subbuffer
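The new helper resolves a tensor to a bound vk_subbuffer in one place: on UMA devices it first tries ggml_vk_host_get, otherwise it falls back to the tensor's device buffer plus vk_tensor_offset and view_offs, then applies the minStorageBufferOffsetAlignment adjustment and returns {buffer, offset, size}. A minimal sketch of the resulting call-site pattern, taken from the vk_op_f32 hunk below (variable names as they appear there; illustrative only, not part of the diff itself):

// Sketch of the post-refactor call site: one helper call per tensor replaces
// the manual UMA probe + buffer-context lookup + offset/alignment bookkeeping.
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous);
vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
vk_subbuffer dst_buf  = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);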
@@ -5387,7 +5387,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
device->pinned_memory.erase(device->pinned_memory.begin() + index);
}

static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
std::lock_guard<std::recursive_mutex> guard(device->mutex);
buf = nullptr;
buf_offset = 0;
@@ -5402,6 +5402,32 @@ static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf
}
}

static vk_subbuffer ggml_vk_tensor_subbuffer(
const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {

vk_buffer buffer = nullptr;
size_t offset = 0;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
}
if (!buffer) {
auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
buffer = buf_ctx->dev_buffer;
offset = vk_tensor_offset(tensor) + tensor->view_offs;
}
GGML_ASSERT(buffer != nullptr);

size_t size = ggml_nbytes(tensor);

size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
// The shader must support misaligned offsets when indexing into the buffer
GGML_ASSERT(allow_misalign || misalign_bytes == 0);
offset &= ~misalign_bytes;
size += misalign_bytes;

return vk_subbuffer{buffer, offset, size};
}

static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
vk_submission s;
s.buffer = ggml_vk_create_cmd_buffer(device, p);
@@ -7953,72 +7979,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr;
size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0;

bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false;

if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
Q_uma = d_Q != nullptr;
K_uma = d_K != nullptr;
V_uma = d_V != nullptr;
D_uma = d_D != nullptr;
if (mask) {
ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
M_uma = d_M != nullptr;
}
if (sinks) {
ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset);
S_uma = d_S != nullptr;
}
}

ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context;
ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;

if (!Q_uma) {
d_Q = q_buf_ctx->dev_buffer;
q_buf_offset = vk_tensor_offset(q) + q->view_offs;
}
if (!K_uma) {
d_K = k_buf_ctx->dev_buffer;
k_buf_offset = vk_tensor_offset(k) + k->view_offs;
}
if (!V_uma) {
d_V = v_buf_ctx->dev_buffer;
v_buf_offset = vk_tensor_offset(v) + v->view_offs;
}
if (!D_uma) {
d_D = d_buf_ctx->dev_buffer;
d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
}

if (!M_uma) {
d_M = d_Q;
m_buf_offset = q_buf_offset;
if (mask) {
ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context;
d_M = m_buf_ctx->dev_buffer;
m_buf_offset = vk_tensor_offset(mask) + mask->view_offs;
}
}

if (!S_uma) {
d_S = d_Q;
s_buf_offset = q_buf_offset;
if (sinks) {
ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context;
d_S = s_buf_ctx->dev_buffer;
s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs;
}
}
vk_subbuffer q_buf = ggml_vk_tensor_subbuffer(ctx, q);
vk_subbuffer k_buf = ggml_vk_tensor_subbuffer(ctx, k);
vk_subbuffer v_buf = ggml_vk_tensor_subbuffer(ctx, v);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;

uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;

@@ -8040,15 +8006,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_sync_buffers(ctx, subctx);
}

vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
},
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
// We only use split_k when group query attention is enabled, which means
// there's no more than one tile of rows (i.e. workgroups_x would have been
// one). We reuse workgroups_x to mean the number of splits, so we need to
@@ -8058,23 +8018,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
ggml_vk_sync_buffers(ctx, subctx);
const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{
ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0),
ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
},
{split_k_buf, sinks_buf, dst_buf},
pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
ctx->prealloc_split_k_need_sync = true;
} else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
ggml_vk_subbuffer(ctx, d_Q, q_buf_offset),
ggml_vk_subbuffer(ctx, d_K, k_buf_offset),
ggml_vk_subbuffer(ctx, d_V, v_buf_offset),
ggml_vk_subbuffer(ctx, d_M, m_buf_offset),
ggml_vk_subbuffer(ctx, d_S, s_buf_offset),
ggml_vk_subbuffer(ctx, d_D, d_buf_offset),
},
{q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
pc, { workgroups_x, workgroups_y, workgroups_z });
}
}
@@ -8757,35 +8706,15 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint64_t ne01 = src0->ne[1];
const uint64_t ne02 = src0->ne[2];
const uint64_t ne03 = src0->ne[3];
const uint64_t ne0 = ne00 * ne01;

const bool use_src1 = src1 != nullptr;
const uint64_t ne10 = use_src1 ? src1->ne[0] : 0;
const uint64_t ne11 = use_src1 ? src1->ne[1] : 0;
const uint64_t ne12 = use_src1 ? src1->ne[2] : 0;
const uint64_t ne13 = use_src1 ? src1->ne[3] : 0;
const uint64_t ne1 = ne10 * ne11;
// const uint64_t nb10 = use_src1 ? src1->nb[0] : 0;

const bool use_src2 = src2 != nullptr;
const uint64_t ne20 = use_src2 ? src2->ne[0] : 0;
const uint64_t ne21 = use_src2 ? src2->ne[1] : 0;
const uint64_t ne22 = use_src2 ? src2->ne[2] : 0;
const uint64_t ne23 = use_src2 ? src2->ne[3] : 0;
const uint64_t ne2 = ne20 * ne21;

const bool use_src3 = src3 != nullptr;
const uint64_t ne30 = use_src3 ? src3->ne[0] : 0;
const uint64_t ne31 = use_src3 ? src3->ne[1] : 0;
const uint64_t ne32 = use_src3 ? src3->ne[2] : 0;
const uint64_t ne33 = use_src3 ? src3->ne[3] : 0;
const uint64_t ne3 = ne30 * ne31;

const uint64_t ned0 = dst->ne[0];
const uint64_t ned1 = dst->ne[1];
const uint64_t ned2 = dst->ne[2];
const uint64_t ned3 = dst->ne[3];
const uint64_t ned = ned0 * ned1;

init_pushconst_fastdiv(pc);

@@ -8804,74 +8733,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co

const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);

ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
ggml_backend_vk_buffer_context * src3_buf_ctx = use_src3 ? (ggml_backend_vk_buffer_context *)src3->buffer->context : nullptr;
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous);
vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);

vk_buffer d_X = nullptr;
size_t x_buf_offset = 0;
vk_buffer d_Y = nullptr;
size_t y_buf_offset = 0;
vk_buffer d_Z = nullptr;
size_t z_buf_offset = 0;
vk_buffer d_W = nullptr;
size_t w_buf_offset = 0;

bool src0_uma = false;
bool src1_uma = false;
bool src2_uma = false;
bool src3_uma = false;

if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
src0_uma = d_X != nullptr;
if (use_src1) {
ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset);
src1_uma = d_Y != nullptr;
}
if (use_src2) {
ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
src2_uma = d_Z != nullptr;
}
if (use_src3) {
ggml_vk_host_get(ctx->device, src3->data, d_W, w_buf_offset);
src3_uma = d_W != nullptr;
}
}

vk_buffer d_D = dst_buf_ctx->dev_buffer;

GGML_ASSERT(d_D != nullptr);
uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
if(!src0_uma) {
d_X = src0_buf_ctx->dev_buffer;
x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
GGML_ASSERT(d_X != nullptr);
}
if (use_src1 && !src1_uma) {
d_Y = src1_buf_ctx->dev_buffer;
y_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
GGML_ASSERT(d_Y != nullptr);
}
if (use_src2 && !src2_uma) {
d_Z = src2_buf_ctx->dev_buffer;
z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
GGML_ASSERT(d_Z != nullptr);
}
if (use_src3 && !src3_uma) {
d_W = src3_buf_ctx->dev_buffer;
w_buf_offset = vk_tensor_offset(src3) + src3->view_offs;
GGML_ASSERT(d_W != nullptr);
}
// Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
// Compute misalignment offset for descriptors and store it in in push constants.
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
w_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);

std::array<uint32_t, 3> elements;

@@ -8955,9 +8824,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
const uint32_t KH = ne01;
const uint32_t KW = ne00;

const uint32_t OD = ned3 / N;
const uint32_t OH = ned2;
const uint32_t OW = ned1;
const uint32_t OD = dst->ne[3] / N;
const uint32_t OH = dst->ne[2];
const uint32_t OW = dst->ne[1];

const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
const uint32_t N_OD_OH = N*OD*OH;
@@ -9072,112 +8941,50 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
break;
}

uint64_t x_sz, y_sz, z_sz, w_sz, d_sz;

if (op_supports_incontiguous) {
x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
w_sz = use_src3 ? ggml_nbytes(src3) + get_misalign_bytes(ctx, src3) : 0;
d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);

if (x_buf_offset + x_sz >= d_X->size) {
x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset);
}
if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset);
}
if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset);
}
if (use_src3 && w_buf_offset + w_sz >= d_W->size) {
w_sz = ggml_vk_get_max_buffer_range(ctx, d_W, w_buf_offset);
}
if (d_buf_offset + d_sz >= d_D->size) {
d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset);
}
} else {
x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03;
y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0;
z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0;
w_sz = use_src3 ? ggml_type_size(src3->type) * ne3 * ne32 * ne33 : 0;
d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3;
}

if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
vk_subbuffer a_buf = src0_buf;
if (ctx->do_add_rms_partials) {
a_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);
}
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{ vk_subbuffer{ d_X, x_buf_offset, x_sz },
vk_subbuffer{ d_Y, y_buf_offset, y_sz },
vk_subbuffer{ d_D, d_buf_offset, d_sz },
ggml_vk_subbuffer(ctx, d_A, a_buf_offset),
}, pc, elements);
{ src0_buf, src1_buf, dst_buf, a_buf }, pc, elements);
} else if (op == GGML_OP_GLU) {
// Empty src1 is possible in glu, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
subbuf_y = { d_Y, y_buf_offset, y_sz };
} else {
subbuf_y = { d_X, 0, x_sz };
}

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc, elements);
} else if (op == GGML_OP_SOFT_MAX) {
// Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
subbuf_y = { d_Y, y_buf_offset, y_sz };
} else {
subbuf_y = { d_X, 0, x_sz };
}

vk_subbuffer subbuf_z;
if (use_src2) {
subbuf_z = { d_Z, z_buf_offset, z_sz };
} else {
subbuf_z = { d_X, 0, x_sz };
}

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, subbuf2, dst_buf }, pc, elements);
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
// Empty src2 is possible in rope, but the shader needs a buffer
vk_subbuffer subbuf_z, subbuf_w;
if (use_src2) {
subbuf_z = { d_Z, z_buf_offset, z_sz };
} else {
subbuf_z = { d_X, 0, x_sz };
}
if (use_src3) {
subbuf_w = { d_W, w_buf_offset, w_sz };
} else {
subbuf_w = { d_X, 0, x_sz };
}

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz }, subbuf_w }, pc, elements);
// Empty src2 and src3 is possible in rope, but the shader needs a buffer
vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
vk_subbuffer subbuf3 = use_src3 ? src3_buf : src0_buf;
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, subbuf2, dst_buf, subbuf3 }, pc, elements);
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
// buffer device address path doesn't use dst buffer
d_sz = 1;
dst_buf.size = 1;
}
// im2col uses only src1 and dst buffers
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src1_buf, dst_buf }, pc, elements);
} else if (op == GGML_OP_COUNT_EQUAL) {
// count_equal assumes that destination buffer is initialized with zeroes
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
ggml_vk_buffer_memset_async(subctx, dst_buf.buffer, dst_buf.offset, 0, dst_buf.size);
ggml_vk_sync_buffers(ctx, subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
} else if (op == GGML_OP_OPT_STEP_SGD) {
// OPT_STEP_SGD works on src0, it does not need dst
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf }, pc, elements);
} else if (use_src3) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_W, w_buf_offset, w_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, src3_buf, dst_buf }, pc, elements);
} else if (use_src2) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, dst_buf }, pc, elements);
} else if (use_src1) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
} else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements);
}
}

@@ -9413,39 +9220,10 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx

ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer src_buf[7] = {};
for (int i = 0; i < num_srcs; i++) {
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}

vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };

if (ctx->device->uma) {
for (int i = 0; i < num_srcs; i++) {
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
srcs_uma[i] = d_srcs[i] != nullptr;
}

ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
dst_uma = d_D != nullptr;
}

uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
for (int i = 0; i < num_srcs; i++) {
src_sizes[i] = ggml_nbytes(dst->src[i]);
if (!srcs_uma[i]) {
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
}
}

const uint64_t dst_size = ggml_nbytes(dst);
if (!dst_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
}

std::array<uint32_t, 3> elements = {
@@ -9455,26 +9233,13 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
};

if (version == 6) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc, elements);
} else if (version == 7) {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
pc, elements);
} else {
// shouldn't happen
GGML_ASSERT(false);
@@ -9554,40 +9319,10 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
n_head, head_dim, n_group, n_tok
};

ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}

vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };

if (ctx->device->uma) {
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
srcs_uma[i] = d_srcs[i] != nullptr;
}
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
dst_uma = d_D != nullptr;
}

if (!dst_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
if (!srcs_uma[i]) {
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
}
}

size_t dst_size = ggml_nbytes(dst);
size_t src_sizes[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_sizes[i] = ggml_nbytes(dst->src[i]);
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
vk_subbuffer src_buf[7] = {};
for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) {
src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
}

std::array<uint32_t, 3> elements;
@@ -9597,16 +9332,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
const uint32_t num_workgroups_y = n_seq;
elements = { num_workgroups_x, num_workgroups_y, 1 };

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
pc, elements);
}

static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
@@ -9653,66 +9381,17 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont

ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;

vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;

if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);

X_uma = d_X != nullptr;
G_uma = d_G != nullptr;
GM_uma = d_GM != nullptr;
GV_uma = d_GV != nullptr;
P_uma = d_P != nullptr;
}

if (!X_uma) {
d_X = x_buf_ctx->dev_buffer;
x_offset = vk_tensor_offset(x) + x->view_offs;
}
if (!G_uma) {
d_G = g_buf_ctx->dev_buffer;
g_offset = vk_tensor_offset(g) + g->view_offs;
}
if (!GM_uma) {
d_GM = gm_buf_ctx->dev_buffer;
gm_offset = vk_tensor_offset(gm) + gm->view_offs;
}
if (!GV_uma) {
d_GV = gv_buf_ctx->dev_buffer;
gv_offset = vk_tensor_offset(gv) + gv->view_offs;
}
if (!P_uma) {
d_P = p_buf_ctx->dev_buffer;
p_offset = vk_tensor_offset(p) + p->view_offs;
}

const uint64_t x_size = ggml_nbytes(x);
const uint64_t g_size = ggml_nbytes(g);
const uint64_t gm_size = ggml_nbytes(gm);
const uint64_t gv_size = ggml_nbytes(gv);
const uint64_t p_size = ggml_nbytes(p);
vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x);
vk_subbuffer g_buf = ggml_vk_tensor_subbuffer(ctx, g);
vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer(ctx, gm);
vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer(ctx, gv);
vk_subbuffer p_buf = ggml_vk_tensor_subbuffer(ctx, p);

std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_X, x_offset, x_size },
vk_subbuffer{ d_G, g_offset, g_size },
vk_subbuffer{ d_GM, gm_offset, gm_size },
vk_subbuffer{ d_GV, gv_offset, gv_size },
vk_subbuffer{ d_P, p_offset, p_size },
}, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{x_buf, g_buf, gm_buf, gv_buf, p_buf},
pc, elements);
}

static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
@@ -10044,45 +9723,9 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,

ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;

vk_buffer d_logits = nullptr;
size_t logits_buf_offset = 0;
vk_buffer d_weights = nullptr;
size_t weights_buf_offset = 0;
vk_buffer d_ids = nullptr;
size_t ids_buf_offset = 0;

bool logits_uma = false;
bool weights_uma = false;
bool ids_uma = false;

if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
logits_uma = d_logits != nullptr;
weights_uma = d_weights != nullptr;
ids_uma = d_ids != nullptr;
}

if (!logits_uma) {
d_logits = logits_buf_ctx->dev_buffer;
logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
GGML_ASSERT(d_logits != nullptr);
}
if (!weights_uma) {
d_weights = weights_buf_ctx->dev_buffer;
weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
GGML_ASSERT(d_weights != nullptr);
}
if (!ids_uma) {
d_ids = ids_buf_ctx->dev_buffer;
ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
GGML_ASSERT(d_ids != nullptr);
}
vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);

vk_op_topk_moe_push_constants pc {};
pc.n_rows = n_rows;
@@ -10098,12 +9741,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
const uint32_t rows_per_block = 4;
std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
}, pc, elements);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
}

static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {