mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	Update Vulkan RoPE implementation (#7818)
* Update Vulkan RoPE implementation * Return nullptr on alloc_buffer when allocation fails, instead of throwing an exception Minor fixes * Fix segfault when running out of VRAM Co-authored-by: slaren <slarengh@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
		| @@ -886,7 +886,7 @@ static bool alloc_tensor_range(struct ggml_context * ctx, | |||||||
|         fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size); |         fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size); | ||||||
| #endif | #endif | ||||||
|         for (size_t i = 0; i < *n_buffers; i++) { |         for (size_t i = 0; i < *n_buffers; i++) { | ||||||
|             ggml_backend_buffer_free(*buffers[i]); |             ggml_backend_buffer_free((*buffers)[i]); | ||||||
|         } |         } | ||||||
|         free(*buffers); |         free(*buffers); | ||||||
|         return false; |         return false; | ||||||
|   | |||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @@ -150,7 +150,7 @@ struct vk_device { | |||||||
|     vk_pipeline pipeline_relu_f32; |     vk_pipeline pipeline_relu_f32; | ||||||
|     vk_pipeline pipeline_diag_mask_inf_f32; |     vk_pipeline pipeline_diag_mask_inf_f32; | ||||||
|     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; |     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; | ||||||
|     vk_pipeline pipeline_rope_f32, pipeline_rope_f16; |     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; | ||||||
|     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; |     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; | ||||||
|     vk_pipeline pipeline_argsort_f32; |     vk_pipeline pipeline_argsort_f32; | ||||||
|     vk_pipeline pipeline_sum_rows_f32; |     vk_pipeline pipeline_sum_rows_f32; | ||||||
| @@ -283,26 +283,15 @@ struct vk_op_diag_mask_push_constants { | |||||||
|  |  | ||||||
| struct vk_op_rope_push_constants { | struct vk_op_rope_push_constants { | ||||||
|     uint32_t ncols; |     uint32_t ncols; | ||||||
|  |     uint32_t n_dims; | ||||||
|     float freq_scale; |     float freq_scale; | ||||||
|     uint32_t p_delta_rows; |     uint32_t p_delta_rows; | ||||||
|     float freq_base; |     float freq_base; | ||||||
|     float ext_factor; |     float ext_factor; | ||||||
|     float attn_factor; |     float attn_factor; | ||||||
|     float corr_dims[4]; |     float corr_dims[2]; | ||||||
| }; |  | ||||||
|  |  | ||||||
| struct vk_op_rope_neox_push_constants { |  | ||||||
|     uint32_t ncols; |  | ||||||
|     uint32_t ndims; |  | ||||||
|     float freq_scale; |  | ||||||
|     uint32_t p_delta_rows; |  | ||||||
|     float freq_base; |  | ||||||
|     float ext_factor; |  | ||||||
|     float attn_factor; |  | ||||||
|     float corr_dims[4]; |  | ||||||
|     float theta_scale; |     float theta_scale; | ||||||
|     float inv_ndims; |     uint32_t has_ff; | ||||||
|     uint32_t has_freq_facs; |  | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct vk_op_soft_max_push_constants { | struct vk_op_soft_max_push_constants { | ||||||
| @@ -1534,11 +1523,11 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) { | |||||||
|     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); |     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); | ||||||
|     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); |     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1); | ||||||
|  |  | ||||||
|     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); | ||||||
|     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); |     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); | ||||||
|  |  | ||||||
|     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1); |     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); | ||||||
|     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_neox_push_constants), {1, 512, 1}, {}, 1); |     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); | ||||||
|  |  | ||||||
|     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); |     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1); | ||||||
|  |  | ||||||
| @@ -3905,10 +3894,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const | |||||||
|                 } |                 } | ||||||
|             } else { |             } else { | ||||||
|                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |                 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { | ||||||
|                     return ctx->device->pipeline_rope_f32; |                     return ctx->device->pipeline_rope_norm_f32; | ||||||
|                 } |                 } | ||||||
|                 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { |                 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { | ||||||
|                     return ctx->device->pipeline_rope_f16; |                     return ctx->device->pipeline_rope_norm_f16; | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|             return nullptr; |             return nullptr; | ||||||
| @@ -4152,10 +4141,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c | |||||||
|             ggml_vk_sync_buffers(subctx); |             ggml_vk_sync_buffers(subctx); | ||||||
|             ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); |             ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); | ||||||
|         } else if (op == GGML_OP_ROPE) { |         } else if (op == GGML_OP_ROPE) { | ||||||
|             const int mode          = ((int32_t *) dst->op_params)[2]; |  | ||||||
|             const bool is_neox = mode & 2; |  | ||||||
|  |  | ||||||
|             if (is_neox) { |  | ||||||
|             // Empty src2 is possible in rope, but the shader needs a buffer |             // Empty src2 is possible in rope, but the shader needs a buffer | ||||||
|             vk_subbuffer subbuf_z; |             vk_subbuffer subbuf_z; | ||||||
|             if (use_src2) { |             if (use_src2) { | ||||||
| @@ -4166,10 +4151,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c | |||||||
|  |  | ||||||
|             ggml_vk_sync_buffers(subctx); |             ggml_vk_sync_buffers(subctx); | ||||||
|             ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); |             ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); | ||||||
|             } else { |  | ||||||
|                 ggml_vk_sync_buffers(subctx); |  | ||||||
|                 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); |  | ||||||
|             } |  | ||||||
|         } else if (use_src2) { |         } else if (use_src2) { | ||||||
|             ggml_vk_sync_buffers(subctx); |             ggml_vk_sync_buffers(subctx); | ||||||
|             ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_Z, z_buf_offset, z_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); |             ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_Z, z_buf_offset, z_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); | ||||||
| @@ -4391,7 +4372,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, | |||||||
|  |  | ||||||
| static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { | static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { | ||||||
|     const int n_dims        = ((int32_t *) dst->op_params)[1]; |     const int n_dims        = ((int32_t *) dst->op_params)[1]; | ||||||
|     const int mode          = ((int32_t *) dst->op_params)[2]; |     // const int mode          = ((int32_t *) dst->op_params)[2]; | ||||||
|     // const int n_ctx         = ((int32_t *) dst->op_params)[3]; |     // const int n_ctx         = ((int32_t *) dst->op_params)[3]; | ||||||
|     const int n_ctx_orig    = ((int32_t *) dst->op_params)[4]; |     const int n_ctx_orig    = ((int32_t *) dst->op_params)[4]; | ||||||
|     const float freq_base   = ((float *)   dst->op_params)[5]; |     const float freq_base   = ((float *)   dst->op_params)[5]; | ||||||
| @@ -4401,28 +4382,16 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con | |||||||
|     const float beta_fast   = ((float *)   dst->op_params)[9]; |     const float beta_fast   = ((float *)   dst->op_params)[9]; | ||||||
|     const float beta_slow   = ((float *)   dst->op_params)[10]; |     const float beta_slow   = ((float *)   dst->op_params)[10]; | ||||||
|  |  | ||||||
|     const bool is_neox = mode & 2; |  | ||||||
|  |  | ||||||
| #pragma message("TODO: update rope NORM mode to match NEOX mode") |  | ||||||
| #pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634") |  | ||||||
|  |  | ||||||
|     float corr_dims[2]; |     float corr_dims[2]; | ||||||
|     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); |     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); | ||||||
|  |  | ||||||
|     if (is_neox) { |  | ||||||
|     const float theta_scale = powf(freq_base, -2.0f/n_dims); |     const float theta_scale = powf(freq_base, -2.0f/n_dims); | ||||||
|         const float inv_ndims = -1.0f / n_dims; |  | ||||||
|         ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { |     ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { | ||||||
|         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], |         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], | ||||||
|             freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f}, theta_scale, inv_ndims, |         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, | ||||||
|         src2 != nullptr, |         src2 != nullptr, | ||||||
|     }); |     }); | ||||||
|     } else { |  | ||||||
|         ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { |  | ||||||
|             (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1], |  | ||||||
|             freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f} |  | ||||||
|         }); |  | ||||||
|     } |  | ||||||
| } | } | ||||||
|  |  | ||||||
| static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { | static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) { | ||||||
| @@ -6070,7 +6039,13 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer( | |||||||
|     std::cerr << "ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")" << std::endl; |     std::cerr << "ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")" << std::endl; | ||||||
| #endif | #endif | ||||||
|     ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; |     ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; | ||||||
|     vk_buffer dev_buffer = ggml_vk_create_buffer_device(ctx->ctx, size); |  | ||||||
|  |     vk_buffer dev_buffer = nullptr; | ||||||
|  |     try { | ||||||
|  |         dev_buffer = ggml_vk_create_buffer_device(ctx->ctx, size); | ||||||
|  |     } catch (const vk::SystemError& e) { | ||||||
|  |         return nullptr; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->ctx, std::move(dev_buffer), ctx->name); |     ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->ctx, std::move(dev_buffer), ctx->name); | ||||||
|  |  | ||||||
| @@ -6466,7 +6441,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const | |||||||
|         //         return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; |         //         return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; | ||||||
|         //     } break; |         //     } break; | ||||||
|         case GGML_OP_ROPE: |         case GGML_OP_ROPE: | ||||||
|             return true; |             return ggml_is_contiguous(op->src[0]); | ||||||
|         case GGML_OP_NONE: |         case GGML_OP_NONE: | ||||||
|         case GGML_OP_RESHAPE: |         case GGML_OP_RESHAPE: | ||||||
|         case GGML_OP_VIEW: |         case GGML_OP_VIEW: | ||||||
|   | |||||||
| @@ -2400,7 +2400,7 @@ void main() { | |||||||
| """ | """ | ||||||
|  |  | ||||||
| # ROPE | # ROPE | ||||||
| rope_src = """ | rope_norm_src = """ | ||||||
| #version 450 | #version 450 | ||||||
|  |  | ||||||
| #extension GL_EXT_shader_16bit_storage : require | #extension GL_EXT_shader_16bit_storage : require | ||||||
| @@ -2408,17 +2408,21 @@ rope_src = """ | |||||||
| layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; | layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; | ||||||
|  |  | ||||||
| layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; | layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; | ||||||
| layout (binding = 1) readonly buffer Y {int data_b[];}; | layout (binding = 1) readonly buffer Y {int data_pos[];}; | ||||||
| layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; | layout (binding = 2) readonly buffer Z {float data_ff[];}; | ||||||
|  | layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; | ||||||
|  |  | ||||||
| layout (push_constant) uniform parameter { | layout (push_constant) uniform parameter { | ||||||
|     uint ncols; |     uint ncols; | ||||||
|  |     uint n_dims; | ||||||
|     float freq_scale; |     float freq_scale; | ||||||
|     uint p_delta_rows; |     uint p_delta_rows; | ||||||
|     float freq_base; |     float freq_base; | ||||||
|     float ext_factor; |     float ext_factor; | ||||||
|     float attn_factor; |     float attn_factor; | ||||||
|     float corr_dims[4]; |     float corr_dims[2]; | ||||||
|  |     float theta_scale; | ||||||
|  |     uint has_ff; | ||||||
| } p; | } p; | ||||||
|  |  | ||||||
| float rope_yarn_ramp(const float low, const float high, const uint i0) { | float rope_yarn_ramp(const float low, const float high, const uint i0) { | ||||||
| @@ -2450,14 +2454,24 @@ void main() { | |||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     if (col >= p.n_dims) { | ||||||
|  |         const uint i = row*p.ncols + col; | ||||||
|  |  | ||||||
|  |         data_d[i + 0] = data_a[i + 0]; | ||||||
|  |         data_d[i + 1] = data_a[i + 1]; | ||||||
|  |  | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     const uint i = row*p.ncols + col; |     const uint i = row*p.ncols + col; | ||||||
|     const uint i2 = row/p.p_delta_rows; |     const uint i2 = row/p.p_delta_rows; | ||||||
|  |  | ||||||
|     const int pos = data_b[i2]; |     const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); | ||||||
|     const float theta_base = pos * pow(p.freq_base, -float(col)/p.ncols); |  | ||||||
|  |     const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; | ||||||
|  |  | ||||||
|     float cos_theta, sin_theta; |     float cos_theta, sin_theta; | ||||||
|     rope_yarn(theta_base, col, cos_theta, sin_theta); |     rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); | ||||||
|  |  | ||||||
|     const float x0 = float(data_a[i + 0]); |     const float x0 = float(data_a[i + 0]); | ||||||
|     const float x1 = float(data_a[i + 1]); |     const float x1 = float(data_a[i + 1]); | ||||||
| @@ -2475,22 +2489,21 @@ rope_neox_src = """ | |||||||
| layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; | layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; | ||||||
|  |  | ||||||
| layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; | layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; | ||||||
| layout (binding = 1) readonly buffer Y {int data_b[];}; | layout (binding = 1) readonly buffer Y {int data_pos[];}; | ||||||
| layout (binding = 2) readonly buffer Z {float data_freq_factors[];}; | layout (binding = 2) readonly buffer Z {float data_ff[];}; | ||||||
| layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; | layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; | ||||||
|  |  | ||||||
| layout (push_constant) uniform parameter { | layout (push_constant) uniform parameter { | ||||||
|     uint ncols; |     uint ncols; | ||||||
|     uint ndims; |     uint n_dims; | ||||||
|     float freq_scale; |     float freq_scale; | ||||||
|     uint p_delta_rows; |     uint p_delta_rows; | ||||||
|     float freq_base; |     float freq_base; | ||||||
|     float ext_factor; |     float ext_factor; | ||||||
|     float attn_factor; |     float attn_factor; | ||||||
|     float corr_dims[4]; |     float corr_dims[2]; | ||||||
|     float theta_scale; |     float theta_scale; | ||||||
|     float inv_ndims; |     uint has_ff; | ||||||
|     uint has_freq_facs; |  | ||||||
| } p; | } p; | ||||||
|  |  | ||||||
| float rope_yarn_ramp(const float low, const float high, const uint i0) { | float rope_yarn_ramp(const float low, const float high, const uint i0) { | ||||||
| @@ -2522,11 +2535,8 @@ void main() { | |||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     const uint ib = col / p.ndims; |     if (col >= p.n_dims) { | ||||||
|     const uint ic = col % p.ndims; |         const uint i = row*p.ncols + col; | ||||||
|  |  | ||||||
|     if (ib > 0) { |  | ||||||
|         const uint i = row*p.ncols + ib*p.ndims + ic; |  | ||||||
|  |  | ||||||
|         data_d[i + 0] = data_a[i + 0]; |         data_d[i + 0] = data_a[i + 0]; | ||||||
|         data_d[i + 1] = data_a[i + 1]; |         data_d[i + 1] = data_a[i + 1]; | ||||||
| @@ -2534,29 +2544,27 @@ void main() { | |||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     const uint i  = row*p.ncols + ib*p.ndims + ic/2; |     const uint i  = row*p.ncols + col/2; | ||||||
|     const uint i2 = row/p.p_delta_rows; |     const uint i2 = row/p.p_delta_rows; | ||||||
|  |  | ||||||
|     const int pos = data_b[i2]; |     const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f); | ||||||
|     const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f; |  | ||||||
|     const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor; |     const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f; | ||||||
|  |  | ||||||
|     float cos_theta, sin_theta; |     float cos_theta, sin_theta; | ||||||
|     rope_yarn(theta_base, ic, cos_theta, sin_theta); |     rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta); | ||||||
|  |  | ||||||
|     const float x0 = float(data_a[i + 0]); |     const float x0 = float(data_a[i + 0]); | ||||||
|     const float x1 = float(data_a[i + p.ndims/2]); |     const float x1 = float(data_a[i + p.n_dims/2]); | ||||||
|  |  | ||||||
|     data_d[i + 0]        = D_TYPE(x0*cos_theta - x1*sin_theta); |     data_d[i + 0]        = D_TYPE(x0*cos_theta - x1*sin_theta); | ||||||
|     data_d[i + p.ndims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); |     data_d[i + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); | ||||||
| } | } | ||||||
| """ | """ | ||||||
|  |  | ||||||
| argsort_src = """ | argsort_src = """ | ||||||
| #version 450 | #version 450 | ||||||
|  |  | ||||||
| #extension GL_EXT_shader_16bit_storage : require |  | ||||||
|  |  | ||||||
| #define BLOCK_SIZE 1024 | #define BLOCK_SIZE 1024 | ||||||
| #define ASC 0 | #define ASC 0 | ||||||
|  |  | ||||||
| @@ -3039,8 +3047,8 @@ async def main(): | |||||||
|     tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"})) |     tasks.append(string_to_spv("soft_max_f32", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "C_TYPE": "float", "D_TYPE": "float"})) | ||||||
|     tasks.append(string_to_spv("soft_max_f32_f16", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float16_t", "C_TYPE": "float16_t", "D_TYPE": "float"})) |     tasks.append(string_to_spv("soft_max_f32_f16", f"{soft_max_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float16_t", "C_TYPE": "float16_t", "D_TYPE": "float"})) | ||||||
|  |  | ||||||
|     tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"})) |     tasks.append(string_to_spv("rope_norm_f32", rope_norm_src, {"A_TYPE": "float", "D_TYPE": "float"})) | ||||||
|     tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) |     tasks.append(string_to_spv("rope_norm_f16", rope_norm_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) | ||||||
|  |  | ||||||
|     tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"})) |     tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"})) | ||||||
|     tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) |     tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 0cc4m
					0cc4m