mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-16 11:27:03 +00:00
vulkan: Remove splitting for mul_mat_id (#15568)
row_ids only needs to hold the BN rows for the current tile.
This commit is contained in:
@@ -93,7 +93,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||
#ifdef MUL_MAT_ID
|
||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||
|
||||
shared u16vec4 row_ids[4096];
|
||||
shared u16vec4 row_ids[BN];
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
|
||||
B_TYPE b[];
|
||||
@@ -111,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
|
||||
return B_TYPE(0.0);
|
||||
}
|
||||
|
||||
const u16vec4 row_idx = row_ids[row_i];
|
||||
const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
|
||||
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
|
||||
|
||||
return ret;
|
||||
@@ -123,14 +123,14 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
|
||||
uint dc = ic * BN + c;
|
||||
|
||||
if (dr < p.M && dc < _ne1) {
|
||||
uint row_i = dc;
|
||||
uint row_i = c;
|
||||
const u16vec4 row_idx = row_ids[row_i];
|
||||
data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
|
||||
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
||||
_ne1 = 0;
|
||||
uint num_elements = p.nei1 * p.nei0;
|
||||
uint nei0shift = findLSB(p.nei0);
|
||||
@@ -180,11 +180,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
|
||||
barrier();
|
||||
|
||||
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
|
||||
if (in_range && id == expert_idx) {
|
||||
row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
|
||||
if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
|
||||
row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
|
||||
}
|
||||
_ne1 += total;
|
||||
iter &= 15;
|
||||
if (_ne1 >= (ic + 1) * BN) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
@@ -218,9 +221,9 @@ void main() {
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
if (bitCount(p.nei0) == 1) {
|
||||
load_row_ids(expert_idx, true);
|
||||
load_row_ids(expert_idx, true, ic);
|
||||
} else {
|
||||
load_row_ids(expert_idx, false);
|
||||
load_row_ids(expert_idx, false, ic);
|
||||
}
|
||||
|
||||
// Workgroup has no work
|
||||
|
||||
Reference in New Issue
Block a user