Mirror of https://github.com/ggml-org/llama.cpp.git
	llama: Add support for RWKV v7 architecture (#12412)
* ggml: Add op l2_norm
* ggml: Add op rwkv_wkv7
* llama: Add support for RWKV7 and ARWKV7 models
* llama: fix inference with RWKV6Qwen2
* llama: add more (a)rwkv7 variants in size
* Apply code-format changes
* fix MUSA build
* llama: fix shape error with rwkv using llama-parallel

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
@@ -2696,6 +2696,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
 static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
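For orientation, GGML_OP_L2_NORM normalizes each row of its input by the row's Euclidean norm. The sketch below is a minimal, backend-agnostic CPU reference of that computation, not the SYCL kernel dispatched above; the eps floor and the contiguous row-major layout are assumptions made for illustration, and the helper name l2_norm_rows_ref is hypothetical.

// Reference sketch of row-wise L2 normalization (illustrative only).
#include <cmath>
#include <cstdint>

static void l2_norm_rows_ref(const float * x, float * y,
                             int64_t nrows, int64_t ncols, float eps) {
    for (int64_t r = 0; r < nrows; ++r) {
        const float * row = x + r * ncols;
        float sum = 0.0f;
        for (int64_t c = 0; c < ncols; ++c) {
            sum += row[c] * row[c];           // squared L2 norm of the row
        }
        // Floor the norm with eps to avoid division by zero (assumed behavior).
        const float scale = 1.0f / std::fmax(std::sqrt(sum), eps);
        for (int64_t c = 0; c < ncols; ++c) {
            y[r * ncols + c] = row[c] * scale;
        }
    }
}

The actual backend kernel batches this over all rows of the tensor; the wrapper above only forwards the tensors through ggml_sycl_op_flatten to that kernel.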
@@ -3410,6 +3416,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_RMS_NORM:
             ggml_sycl_rms_norm(ctx, dst);
             break;
+        case GGML_OP_L2_NORM:
+            ggml_sycl_l2_norm(ctx, dst);
+            break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 return false;
@@ -3487,6 +3496,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
         case GGML_OP_RWKV_WKV6:
             ggml_sycl_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_sycl_op_rwkv_wkv7(ctx, dst);
+            break;
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_sycl_op_gated_linear_attn(ctx, dst);
             break;
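GGML_OP_RWKV_WKV7 implements the per-head state recurrence of the RWKV-7 architecture. As rough orientation only, the sketch below writes out that recurrence for a single head and a single timestep; the state layout (state[i*N + j], with i over the value dimension and j over the key dimension) and the exact roles of a and b follow the published RWKV-7 formulation and are assumptions here, not a transcription of ggml_sycl_op_rwkv_wkv7.

// Illustrative single-head, single-timestep WKV7 update (assumed conventions):
//   state_new = state_old * diag(w)        per-channel decay
//             + (state_old a) outer b      rank-1 "in-context" correction
//             + v outer k                  key/value write
//   y         = state_new r                read-out with the receptance vector
// All vectors have length N (the head size); state is an N x N matrix.
#include <vector>
#include <cstddef>

static void wkv7_step_ref(std::size_t N,
                          const float * r, const float * w, const float * k,
                          const float * v, const float * a, const float * b,
                          std::vector<float> & state,   // N*N, updated in place
                          float * y) {                  // N outputs
    // sa[i] = sum_j a[j] * state[i][j]
    std::vector<float> sa(N, 0.0f);
    for (std::size_t i = 0; i < N; ++i) {
        for (std::size_t j = 0; j < N; ++j) {
            sa[i] += a[j] * state[i * N + j];
        }
    }
    for (std::size_t i = 0; i < N; ++i) {
        float out = 0.0f;
        for (std::size_t j = 0; j < N; ++j) {
            float s = state[i * N + j] * w[j]   // decay
                    + sa[i] * b[j]              // rank-1 correction
                    + v[i] * k[j];              // kv outer product
            state[i * N + j] = s;
            out += s * r[j];                    // read-out
        }
        y[i] = out;
    }
}

In practice the op runs this over every token, sequence, and head in the batch, and the backend kernel fuses the loops; the sketch only shows what the dispatch entry above ultimately computes.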
@@ -4012,6 +4024,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return (op->src[0]->type == GGML_TYPE_F32);
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_SCALE:
@@ -4045,6 +4058,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
         case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         default: