cuda : add RoPE kernel for mode == 2 (NeoX) (#2760)

* cuda : add RoPE kernel for mode == 2 (NeoX)

* falcon : do not offload the embeddings layer
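For context: the existing rope_f32 kernel rotates adjacent element pairs (x[2*i], x[2*i + 1]), while the NeoX-style RoPE added here (mode == 2) pairs each element with the one half the rotated width away (x[i], x[i + ncols/2]). Below is a minimal single-row CPU sketch of the rotation the new rope_neox_f32 kernel performs per pair; the helper name rope_neox_ref and the demo values are illustrative only, not part of this commit.

#include <math.h>
#include <stdio.h>

// Single-row CPU reference of the NeoX rotation (illustrative, not from the commit).
static void rope_neox_ref(const float * x, float * dst, int ncols, float p0, float theta_scale) {
    for (int col = 0; col < ncols; col += 2) {
        // same angle schedule as the kernel for row 0: theta = p0 * theta_scale^(col/2)
        const float theta     = p0 * powf(theta_scale, col/2);
        const float sin_theta = sinf(theta);
        const float cos_theta = cosf(theta);

        // NeoX pairing: element i rotates together with element i + ncols/2
        const int   i  = col/2;
        const float x0 = x[i + 0];
        const float x1 = x[i + ncols/2];

        dst[i + 0]       = x0*cos_theta - x1*sin_theta;
        dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
    }
}

int main() {
    const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float dst[4];
    rope_neox_ref(x, dst, 4, /*p0=*/1.0f, /*theta_scale=*/0.5f);
    printf("%f %f %f %f\n", dst[0], dst[1], dst[2], dst[3]);
    return 0;
}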
ggml-cuda.cu (58 changed lines)
@@ -3907,28 +3907,27 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-// TODO: this implementation is wrong!
-//static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
-//                                const float p_delta, const int p_delta_rows, const float theta_scale) {
-//    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
-//
-//    if (col >= ncols) {
-//        return;
-//    }
-//
-//    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-//    const int i = row*ncols + col/2;
-//
-//    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
-//    const float sin_theta = sinf(theta);
-//    const float cos_theta = cosf(theta);
-//
-//    const float x0 = x[i + 0];
-//    const float x1 = x[i + ncols/2];
-//
-//    dst[i + 0]       = x0*cos_theta - x1*sin_theta;
-//    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
-//}
+static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
+                                const float p_delta, const int p_delta_rows, const float theta_scale) {
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int i = row*ncols + col/2;
+
+    const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
+    const float sin_theta = sinf(theta);
+    const float cos_theta = cosf(theta);
+
+    const float x0 = x[i + 0];
+    const float x1 = x[i + ncols/2];
+
+    dst[i + 0]       = x0*cos_theta - x1*sin_theta;
+    dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+}
 
 static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
@@ -4799,13 +4798,21 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
-    GGML_ASSERT(nrows % 2 == 0);
+    GGML_ASSERT(nrows % 2 == 0); // GG: is this assert really needed? I don't see why
     const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
+static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
+                          const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
+    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
+    rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
+}
+
 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 4 == 0);
     const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
@@ -5548,8 +5555,9 @@ inline void ggml_cuda_op_rope(
         const float block_p = max(p - (n_ctx - 2.f), 0.f);
         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
     } else if (is_neox) {
-        GGML_ASSERT(false && "RoPE NeoX not implemented yet");
-#pragma message("TODO: implement RoPE NeoX for CUDA")
+        GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
+        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
+        rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     } else {
         const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
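The new rope_neox_f32_cuda launcher reuses the launch geometry of rope_f32_cuda: one block per tensor row along grid.x, ceil(ncols / (2*CUDA_ROPE_BLOCK_SIZE)) blocks along grid.y, and one rotated element pair per thread, with threads past ncols returning early. A small host-side sketch of that arithmetic follows; CUDA_ROPE_BLOCK_SIZE = 256 and the tensor sizes are assumptions for illustration, not values taken from this commit.

#include <stdio.h>

int main() {
    // Assumed value; check ggml-cuda.cu for the actual CUDA_ROPE_BLOCK_SIZE.
    const int CUDA_ROPE_BLOCK_SIZE = 256;

    // Made-up example sizes: ncols = rotated width per row, nrows = number of rows.
    const int ncols = 128;
    const int nrows = 4096;

    // Mirrors the launcher: block_dims = (1, 2*CUDA_ROPE_BLOCK_SIZE, 1)
    const int threads_y = 2*CUDA_ROPE_BLOCK_SIZE;
    // num_blocks_x = ceil(ncols / (2*CUDA_ROPE_BLOCK_SIZE)); block_nums = (nrows, num_blocks_x, 1)
    const int num_blocks_x = (ncols + threads_y - 1) / threads_y;

    printf("grid = (%d, %d, 1), block = (1, %d, 1)\n", nrows, num_blocks_x, threads_y);
    // Each thread computes col = 2*(blockDim.y*blockIdx.y + threadIdx.y) and
    // rotates the pair for that column, exiting early if col >= ncols.
    return 0;
}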
							
								
								
									
llama.cpp (22 changed lines)
@@ -1958,6 +1958,14 @@ static void llm_load_tensors(
                         model.output_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
                         model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd},          backend_norm);
                         model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+
+                        if (backend_norm == GGML_BACKEND_GPU) {
+                            vram_weights += ggml_nbytes(model.output_norm);
+                            vram_weights += ggml_nbytes(model.output_norm_b);
+                        }
+                        if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+                            vram_weights += ggml_nbytes(model.output);
+                        }
                     }
 
                     const uint32_t n_ff = hparams.n_ff;
@@ -1967,7 +1975,7 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+                        const ggml_backend backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
                         const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
 
                         auto & layer = model.layers[i];
@@ -1978,6 +1986,11 @@ static void llm_load_tensors(
                         if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) {
                             layer.attn_norm_2   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend);
                             layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, backend);
+
+                            if (backend == GGML_BACKEND_GPU) {
+                                vram_weights += ggml_nbytes(layer.attn_norm_2);
+                                vram_weights += ggml_nbytes(layer.attn_norm_2_b);
+                            }
                         }
 
                         layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
@@ -1985,6 +1998,13 @@ static void llm_load_tensors(
 
                         layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
                         layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+
+                        if (backend == GGML_BACKEND_GPU) {
+                            vram_weights +=
+                                ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
+                                ggml_nbytes(layer.wqkv)      + ggml_nbytes(layer.wo)          +
+                                ggml_nbytes(layer.w2)        + ggml_nbytes(layer.w3);
+                        }
                     }
                 } break;
             default:
Author: Georgi Gerganov