mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	cuda : add set rows for bf16 (#14664)
This commit is contained in:
		| @@ -3226,8 +3226,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g | ||||
|             } break; | ||||
|         case GGML_OP_SET_ROWS: | ||||
|             { | ||||
| #pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") | ||||
|                 return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && | ||||
| #pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") | ||||
|                 return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) && | ||||
|                        op->src[0]->type == GGML_TYPE_F32 && | ||||
|                        op->src[1]->type == GGML_TYPE_I64; | ||||
|             } break; | ||||
|   | ||||
| @@ -10,6 +10,11 @@ __device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, hal | ||||
|     *dst_h = __float2half(*src_f); | ||||
| } | ||||
|  | ||||
| template<> | ||||
| __device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) { | ||||
|     *dst_b = *src_f; | ||||
| } | ||||
|  | ||||
| template<> | ||||
| __device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) { | ||||
|     *dst_f = *src_f; | ||||
| @@ -124,6 +129,16 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||||
|             nb1, nb2, nb3, | ||||
|             stream | ||||
|         ); | ||||
|     } else if (dst->type == GGML_TYPE_BF16) { | ||||
|         set_rows_cuda( | ||||
|             src0_d, src1_d, (nv_bfloat16*)dst->data, | ||||
|             ne00, ne01, ne02, ne03, | ||||
|             ne10, ne11, ne12, ne13, | ||||
|             nb01, nb02, nb03, | ||||
|             nb10, nb11, nb12, | ||||
|             nb1, nb2, nb3, | ||||
|             stream | ||||
|         ); | ||||
|     } else { | ||||
|         GGML_ABORT("unsupported type"); | ||||
|     } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Sigbjørn Skjæret
					Sigbjørn Skjæret