Mirror of https://github.com/ggml-org/llama.cpp.git — synced 2025-11-04 09:32:00 +00:00
			
		
		
		
	SYCL : SOFTMAX F16 mask support and other fixes (#11261)
Implemented ggml_sycl_op_soft_max() F16 src1 (mask) support, for which a pragma deprecation warning was added during #5021. To do this, it had to be decoupled from ggml_sycl_op_flatten, which always considered src1 to be of fp32 type (many OP functions depend on it). * SYCL: SOFTMAX F16 mask support and other fixes * test-backend-ops: Add F16 mask test cases
This commit is contained in:
		@@ -3878,10 +3878,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
 | 
			
		||||
    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Dispatch the SOFT_MAX op for the SYCL backend by forwarding src0 (input)
// and src1 (optional mask) through the generic ggml_sycl_op_flatten wrapper.
// NOTE(review): per the commit message above, ggml_sycl_op_flatten always
// treats src1 as fp32, so an F16 mask cannot reach ggml_sycl_op_soft_max via
// this path — this wrapper is what the commit decouples/removes. Confirm
// against the post-commit file before relying on it.
static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
}
 | 
			
		||||
 | 
			
		||||
static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 | 
			
		||||
    GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
 | 
			
		||||
    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
 | 
			
		||||
@@ -4090,7 +4086,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
 | 
			
		||||
            ggml_sycl_diag_mask_inf(ctx, dst);
 | 
			
		||||
            break;
 | 
			
		||||
        case GGML_OP_SOFT_MAX:
 | 
			
		||||
            ggml_sycl_soft_max(ctx, dst);
 | 
			
		||||
            ggml_sycl_op_soft_max(ctx, dst);
 | 
			
		||||
            break;
 | 
			
		||||
        case GGML_OP_ROPE:
 | 
			
		||||
            ggml_sycl_rope(ctx, dst);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user