Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-11-03 09:22:01 +00:00.
			
		
		
		
vulkan: Split large mul_mat_id to fit in shared memory (#14451)
This commit is contained in: (branch/tag list not captured in this extract)
		@@ -5966,7 +5966,30 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
 | 
			
		||||
    if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
 | 
			
		||||
        ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
 | 
			
		||||
    } else {
 | 
			
		||||
        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
 | 
			
		||||
        // Split based on number of ids, to fit in shared memory
 | 
			
		||||
        const uint32_t nei0 = (uint32_t)src2->ne[0];
 | 
			
		||||
        const uint32_t nei1 = (uint32_t)src2->ne[1];
 | 
			
		||||
 | 
			
		||||
        GGML_ASSERT(nei0 <= 4096);
 | 
			
		||||
        const uint32_t split_size = std::min(nei1, 4096u / nei0);
 | 
			
		||||
 | 
			
		||||
        ggml_tensor src1_copy = *src1;
 | 
			
		||||
        ggml_tensor src2_copy = *src2;
 | 
			
		||||
        ggml_tensor dst_copy = *dst;
 | 
			
		||||
 | 
			
		||||
        for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
 | 
			
		||||
            const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
 | 
			
		||||
 | 
			
		||||
            src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
 | 
			
		||||
            src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
 | 
			
		||||
            dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
 | 
			
		||||
 | 
			
		||||
            src1_copy.ne[2] = n_tokens;
 | 
			
		||||
            src2_copy.ne[1] = n_tokens;
 | 
			
		||||
            dst_copy.ne[2] = n_tokens;
 | 
			
		||||
 | 
			
		||||
            ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -10135,9 +10158,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 | 
			
		||||
                ggml_type src0_type = op->src[0]->type;
 | 
			
		||||
                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
 | 
			
		||||
                const vk_device& device = ggml_vk_get_device(ctx->device);
 | 
			
		||||
                if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
 | 
			
		||||
                    // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
 | 
			
		||||
                    return false;
 | 
			
		||||
                if (op->op == GGML_OP_MUL_MAT_ID) {
 | 
			
		||||
                    if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
 | 
			
		||||
                        // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
 | 
			
		||||
                        return false;
 | 
			
		||||
                    }
 | 
			
		||||
                    // Check against size of shared memory variable
 | 
			
		||||
                    if (op->src[2]->ne[0] > 4096) {
 | 
			
		||||
                        return false;
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                switch (src0_type) {
 | 
			
		||||
                    case GGML_TYPE_F32:
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user