	llama: add support for QRWKV6 model architecture (#11001)
* WIP: Add support for RWKV6Qwen2
* RWKV: Some graph simplification
* Add support for RWKV6Qwen2 with cpu and cuda GLA
* RWKV6[QWEN2]: Concat lerp weights together to reduce cpu overhead
* Fix some typos
* Code format changes
* Fix wkv test & add gla test
* Fix cuda warning
* Update README.md
* Update ggml/src/ggml-cuda/gla.cu
* Fix fused lerp weights loading with RWKV6
* Better sanity check skipping for QRWKV6 in llama-quant (thanks @compilade)

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: compilade <git@compilade.net>
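For orientation: the new GGML_OP_GATED_LINEAR_ATTN CPU kernel in the diff below implements a per-head gated linear attention recurrence. The following is a minimal scalar sketch of that recurrence, with illustrative names that are not taken from the patch; the kernel itself vectorizes the inner j loop and fuses the state update with the readout.

    // One head of one sequence; S is the head_size x head_size recurrent state.
    // Per token t:  S[i][j] = g[i]*S[i][j] + k[i]*v[j];  out[j] += scale*q[i]*S[i][j]
    static void gla_head_ref(int64_t T, int64_t head_size, float scale,
                             const float * k, const float * v,   // [T][head_size]
                             const float * q, const float * g,   // [T][head_size]
                             float * S,                          // [head_size][head_size]
                             float * out) {                      // [T][head_size]
        for (int64_t t = 0; t < T; t++) {
            const float * kt = k + t*head_size;
            const float * vt = v + t*head_size;
            const float * qt = q + t*head_size;
            const float * gt = g + t*head_size;
            float * ot = out + t*head_size;
            for (int64_t j = 0; j < head_size; j++) {
                ot[j] = 0.0f;
            }
            for (int64_t i = 0; i < head_size; i++) {
                for (int64_t j = 0; j < head_size; j++) {
                    const float s = gt[i]*S[i*head_size + j] + kt[i]*vt[j];
                    S[i*head_size + j] = s;          // gated decay + outer product
                    ot[j] += scale*qt[i]*s;          // readout, summed over i
                }
            }
        }
    }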
@@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;
 
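The hunk above tracks a reshape of the wkv6 operands elsewhere in the PR: the token count now sits in src[1]->ne[2] and the head count in src[1]->ne[1]. A hedged sanity-check sketch of the layout the new indexing assumes (the assertions are illustrative, not part of the patch):

    // Assumed layout after this change: src[1] (v) is [head_size, n_heads, n_tokens].
    const struct ggml_tensor * v = dst->src[1];
    GGML_ASSERT(v->ne[0] == head_size); // fastest-varying dim is the head dim
    GGML_ASSERT(v->ne[1] == HEADS);     // heads moved from ne[2] to ne[1]
    GGML_ASSERT(v->ne[2] == T);         // tokens moved from ne[3] to ne[2]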
@@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth; // split heads across threads
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same as C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load GLA_VECTOR_SIZE elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements (normally none: head_size is a multiple of GLA_VECTOR_SIZE)
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
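One convention worth noting before the dispatch hunks: the kernel writes both of its results into a single dst buffer, per-token outputs first, then the updated recurrent state. A small sketch of how a consumer could address the two regions (pointer arithmetic only; seq_id is an illustrative index, not a name from the patch):

    // dst->data layout written by ggml_compute_forward_gla_f32:
    //   [0 .. T*C)                           per-token outputs
    //   [T*C .. T*C + n_seqs*head_size*C)    updated state, one head_size*C block per sequence
    float * outputs      = (float *) dst->data;
    float * states       = (float *) dst->data + T * C;
    float * state_of_seq = states + seq_id * head_size * C;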
@@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
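Finally, for context on how the op is constructed: the PR also adds a graph builder for GGML_OP_GATED_LINEAR_ATTN in ggml.h, not shown in these hunks. A hedged usage sketch, assuming the builder takes the operands in the order the kernel reads them (src[0]=k, src[1]=v, src[2]=q, src[3]=g, src[4]=state) plus the scale op param; the exact signature lives in the ggml.h hunk of this PR:

    // k, v, q, g: F32 activations laid out as [head_size, n_heads, n_tokens]
    // state:      F32 recurrent state, head_size*C values per sequence (n_seqs in ne[1])
    struct ggml_tensor * out = ggml_gated_linear_attn(
            ctx, k, v, q, g, state, 1.0f / sqrtf((float) head_size));
    // out packs n_tokens*C outputs followed by the updated per-sequence state,
    // matching the CPU kernel above.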