Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-10-30 08:42:00 +00:00
	rwkv6: add wkv6 support for Vulkan backend (#10829)
* rwkv_wkv6 vulkan shader
* RWKV_WKV6 Vulkan op tests passed
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* Apply code format changes
  Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
* add [[unroll]] and remove unnecessary conditions
* add uma support
* fix errors in EditorConfig Checker

---------

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Molly Sophia <mollysophia379@gmail.com>
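For readers new to the op: wkv6 is the linear-attention recurrence of RWKV v6. Reading the inner loop of the new shader (one invocation per output channel i, with receptance r, key k, value v, per-channel decay td, bonus tf, and a 64x64 per-head state S), each token step evaluates:

$$
y_{t,i} \;=\; \sum_{j=0}^{63} r_{t,j}\,\bigl(tf_j\,k_{t,j}\,v_{t,i} + S_{j,i}\bigr),
\qquad
S_{j,i} \;\leftarrow\; td_{t,j}\,S_{j,i} + k_{t,j}\,v_{t,i}
$$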
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -245,6 +245,7 @@ struct vk_device_struct {
    vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
    vk_pipeline pipeline_timestep_embedding_f32;
    vk_pipeline pipeline_pool2d_f32;
    vk_pipeline pipeline_rwkv_wkv6_f32;

    // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
    vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -528,6 +529,13 @@ struct vk_op_pool2d_push_constants {
    int32_t p0; int32_t p1;
};

struct vk_op_rwkv_wkv6_push_constants {
    uint32_t B;
    uint32_t T;
    uint32_t C;
    uint32_t H;
};

// Allow pre-recording command buffers
struct vk_staging_memcpy {
    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
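The field meanings are not documented in the commit; the comments below are an inferred reading from how ggml_vk_rwkv_wkv6, further down in this diff, fills the struct:

```cpp
// Inferred reading of vk_op_rwkv_wkv6_push_constants (annotation for this
// page, not comments from the commit):
//   B = n_seqs     (src[5]->ne[1])  number of independent sequences
//   T = seq_length (src[0]->ne[3])  token count; the shader derives n_seq_tokens = T / B
//   C = n_embed    (dst->ne[0])     channels; the shader's indexing implies C = H * 64
//   H = n_heads    (src[0]->ne[2])  heads, each with head size 64
```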
@@ -2014,6 +2022,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

    ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

    for (auto &c : compiles) {
        c.wait();
    }
@@ -5022,6 +5032,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_pool2d_f32;
        }
        return nullptr;
    case GGML_OP_RWKV_WKV6:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_rwkv_wkv6_f32;
        }
        return nullptr;
    case GGML_OP_LEAKY_RELU:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_leaky_relu_f32;
@@ -5424,6 +5439,134 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
    }, dryrun);
}

static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
    const ggml_tensor * k = dst->src[0];
    const ggml_tensor * v = dst->src[1];
    const ggml_tensor * r = dst->src[2];
    const ggml_tensor * tf = dst->src[3];
    const ggml_tensor * td = dst->src[4];
    const ggml_tensor * state = dst->src[5];

    GGML_ASSERT(!ggml_is_quantized(k->type));
    GGML_ASSERT(!ggml_is_quantized(v->type));
    GGML_ASSERT(!ggml_is_quantized(r->type));
    GGML_ASSERT(!ggml_is_quantized(tf->type));
    GGML_ASSERT(!ggml_is_quantized(td->type));
    GGML_ASSERT(!ggml_is_quantized(state->type));
    GGML_ASSERT(dst->buffer != nullptr);

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
    GGML_ASSERT(pipeline != nullptr);

    if (dryrun) {
        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
        return;
    }

    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;

    ggml_vk_sync_buffers(subctx);

    vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
    uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;

    if (ctx->device->uma) {
        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);

        K_uma = d_K != nullptr;
        V_uma = d_V != nullptr;
        R_uma = d_R != nullptr;
        TF_uma = d_TF != nullptr;
        TD_uma = d_TD != nullptr;
        STATE_uma = d_State != nullptr;
        DST_uma = d_D != nullptr;
    }

    if (!K_uma) {
        d_K = k_buf_ctx->dev_buffer;
        k_offset = vk_tensor_offset(k) + k->view_offs;
    }
    if (!V_uma) {
        d_V = v_buf_ctx->dev_buffer;
        v_offset = vk_tensor_offset(v) + v->view_offs;
    }
    if (!R_uma) {
        d_R = r_buf_ctx->dev_buffer;
        r_offset = vk_tensor_offset(r) + r->view_offs;
    }
    if (!TF_uma) {
        d_TF = tf_buf_ctx->dev_buffer;
        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
    }
    if (!TD_uma) {
        d_TD = td_buf_ctx->dev_buffer;
        td_offset = vk_tensor_offset(td) + td->view_offs;
    }
    if (!STATE_uma) {
        d_State = state_buf_ctx->dev_buffer;
        state_offset = vk_tensor_offset(state) + state->view_offs;
    }
    if (!DST_uma) {
        d_D = dst_buf_ctx->dev_buffer;
        dst_offset = vk_tensor_offset(dst) + dst->view_offs;
    }

    const uint64_t k_size = ggml_nbytes(k);
    const uint64_t v_size = ggml_nbytes(v);
    const uint64_t r_size = ggml_nbytes(r);
    const uint64_t tf_size = ggml_nbytes(tf);
    const uint64_t td_size = ggml_nbytes(td);
    const uint64_t state_size = ggml_nbytes(state);
    const uint64_t dst_size = ggml_nbytes(dst);

    std::array<uint32_t, 3> elements = {
        (uint32_t)(pc.B * pc.H),
        1,
        1
    };

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
        vk_subbuffer{ d_K, k_offset, k_size },
        vk_subbuffer{ d_V, v_offset, v_size },
        vk_subbuffer{ d_R, r_offset, r_size },
        vk_subbuffer{ d_TF, tf_offset, tf_size },
        vk_subbuffer{ d_TD, td_offset, td_size },
        vk_subbuffer{ d_State, state_offset, state_size },
        vk_subbuffer{ d_D, dst_offset, dst_size }
    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
}

static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
    const size_t seq_length = dst->src[0]->ne[3];
    const size_t n_embed = dst->ne[0];
    const size_t n_heads = dst->src[0]->ne[2];
    const size_t n_seqs = dst->src[5]->ne[1];

    ggml_vk_op_f32_rwkv6(
        ctx, subctx, dst,
        {
            (uint32_t)n_seqs,
            (uint32_t)seq_length,
            (uint32_t)n_embed,
            (uint32_t)n_heads,
        },
        dryrun
    );
}
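A note on the dispatch geometry these two functions set up: elements is {B * H, 1, 1}, so one workgroup is launched per (sequence, head) pair, and each of the shader's 64 invocations owns one of that head's 64 output channels. With B = 2 and H = 32, for example, that is 64 workgroups of 64 lanes, and each lane carries one 64-element column of its head's 64x64 state in registers across the whole token loop.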

static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    int * op_params = (int *)dst->op_params;

@@ -6569,6 +6712,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_IM2COL:
    case GGML_OP_TIMESTEP_EMBEDDING:
    case GGML_OP_POOL_2D:
    case GGML_OP_RWKV_WKV6:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_FLASH_ATTN_EXT:
        break;
@@ -6768,6 +6912,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_FLASH_ATTN_EXT:
        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);

        break;

    case GGML_OP_RWKV_WKV6:
        ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);

        break;
    default:
        return false;
@@ -6848,6 +6997,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    case GGML_OP_IM2COL:
    case GGML_OP_TIMESTEP_EMBEDDING:
    case GGML_OP_POOL_2D:
    case GGML_OP_RWKV_WKV6:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_REPEAT:
        buf = tensor->buffer;
@@ -7724,6 +7874,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_IM2COL:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_POOL_2D:
        case GGML_OP_RWKV_WKV6:
        case GGML_OP_LEAKY_RELU:
            return true;
        default:
@@ -8300,7 +8451,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
    } else if (tensor->op == GGML_OP_LEAKY_RELU) {
        const float * op_params = (const float *)tensor->op_params;
        tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
    } else {
    } else if (tensor->op == GGML_OP_RWKV_WKV6) {
        tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
        tensor->src[4], tensor->src[5]);
    }
    else {
        std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
        GGML_ABORT("fatal error");
    }
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -479,6 +479,8 @@ void process_shaders() {

    string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

    string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

    for (auto &c : compiles) {
        c.wait();
    }
ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp (new file, 87 lines)

@@ -0,0 +1,87 @@
#version 450

#extension GL_EXT_control_flow_attributes : require

#define BLOCK_SIZE 64
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout(push_constant) uniform Parameters {
    uint B;
    uint T;
    uint C;
    uint H;
};

layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; };
layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; };
layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; };
layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; };
layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; };
layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; };
layout(binding = 6) buffer DstBuf { A_TYPE dst[]; };

shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE];

void main() {
    const uint head_size = BLOCK_SIZE;
    const uint batch_id = gl_WorkGroupID.x / H;
    const uint head_id = gl_WorkGroupID.x % H;
    const uint tid = gl_LocalInvocationID.x;

    const uint state_size = C * head_size;
    const uint n_seq_tokens = T / B;

    if (batch_id >= B || head_id >= H) {
        return;
    }

    A_TYPE state[BLOCK_SIZE];
    [[unroll]] for (uint i = 0; i < head_size; i++) {
        state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
                          + i * head_size + tid];
    }

    barrier();
    _tf[tid] = tf[head_id * head_size + tid];
    barrier();

    const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
    const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;

    for (uint t = start_t; t < end_t; t += C) {
        barrier();
        _k[tid] = k[t];
        _r[tid] = r[t];
        _td[tid] = td[t];
        barrier();

        const A_TYPE v_val = v[t];
        A_TYPE y = 0.0;

        [[unroll]] for (uint j = 0; j < head_size; j += 4) {
            vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
            vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
            vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
            vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
            vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);

            vec4 kv = k_vec * v_val;

            vec4 temp = tf_vec * kv + s_vec;
            y += dot(r_vec, temp);

            s_vec = s_vec * td_vec + kv;
            state[j] = s_vec.x;
            state[j+1] = s_vec.y;
            state[j+2] = s_vec.z;
            state[j+3] = s_vec.w;
        }

        dst[t] = y;
    }

    [[unroll]] for (uint i = 0; i < head_size; i++) {
        dst[T * C + batch_id * state_size + head_id * head_size * head_size
            + i * head_size + tid] = state[i];
    }
}
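To make the shader's indexing easier to audit, here is a plain scalar restatement of the same computation in C++. This is a hedged reading aid written for this page, not code from the commit; wkv6_reference is a made-up name, and it assumes T is the total token count across the B sequences, which is what n_seq_tokens = T / B in the shader implies. Each (b, h) iteration pair below corresponds to one workgroup, and lane tid keeps column tid of s in registers:

```cpp
#include <cstdint>
#include <vector>

// Scalar reference for the recurrence in wkv6.comp (head size fixed at 64).
// Buffers use the flat layouts the shader indexes; dst must hold T*C outputs
// followed by the B*H*64*64 updated state, exactly as the shader writes them.
void wkv6_reference(uint32_t B, uint32_t T, uint32_t C, uint32_t H,
                    const float * k, const float * v, const float * r,
                    const float * tf, const float * td,
                    const float * state_in, float * dst) {
    const uint32_t head_size    = 64;             // BLOCK_SIZE in the shader
    const uint32_t state_size   = C * head_size;  // per-sequence state elements
    const uint32_t n_seq_tokens = T / B;          // tokens per sequence
    std::vector<float> s(head_size * head_size);  // one head's 64x64 state
    for (uint32_t b = 0; b < B; b++) {
        for (uint32_t h = 0; h < H; h++) {
            // s[j*64 + i] pairs key channel j with output channel i; shader
            // lane i keeps column i of this matrix in its state[] registers
            for (uint32_t e = 0; e < head_size * head_size; e++) {
                s[e] = state_in[b * state_size + h * head_size * head_size + e];
            }
            for (uint32_t t = 0; t < n_seq_tokens; t++) {
                const uint32_t base = (b * n_seq_tokens + t) * C + h * head_size;
                for (uint32_t i = 0; i < head_size; i++) {      // output channel (tid)
                    float y = 0.0f;
                    for (uint32_t j = 0; j < head_size; j++) {  // key channel
                        const float kv = k[base + j] * v[base + i];
                        // bonus term reads the pre-update state, as in the shader
                        y += r[base + j] * (tf[h * head_size + j] * kv + s[j * head_size + i]);
                        s[j * head_size + i] = s[j * head_size + i] * td[base + j] + kv;
                    }
                    dst[base + i] = y;
                }
            }
            // the updated state is appended after the T*C outputs
            for (uint32_t e = 0; e < head_size * head_size; e++) {
                dst[T * C + b * state_size + h * head_size * head_size + e] = s[e];
            }
        }
    }
}
```

Note that each output channel i touches only column i of the state, so the shader's 64 lanes can run the token loop in parallel with no cross-lane traffic beyond the shared _k/_r/_td/_tf tiles loaded between barriers.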