mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
@@ -1361,7 +1361,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
|
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
|
||||||
@@ -1521,6 +1520,9 @@ static id<MTLComputePipelineState> ggml_metal_compile_kernel(ggml_backend_t back
|
|||||||
NSString * key = [NSString stringWithUTF8String:name];
|
NSString * key = [NSString stringWithUTF8String:name];
|
||||||
[ctx->kernels_ext setObject:obj forKey:key];
|
[ctx->kernels_ext setObject:obj forKey:key];
|
||||||
|
|
||||||
|
[metal_function release];
|
||||||
|
[obj release];
|
||||||
|
|
||||||
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) kernel.pipeline,
|
GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) kernel.pipeline,
|
||||||
(int) kernel.pipeline.maxTotalThreadsPerThreadgroup,
|
(int) kernel.pipeline.maxTotalThreadsPerThreadgroup,
|
||||||
(int) kernel.pipeline.threadExecutionWidth);
|
(int) kernel.pipeline.threadExecutionWidth);
|
||||||
@@ -1542,8 +1544,6 @@ static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
|
|||||||
char name[256];
|
char name[256];
|
||||||
|
|
||||||
@autoreleasepool {
|
@autoreleasepool {
|
||||||
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
|
|
||||||
|
|
||||||
const int32_t dk = (int32_t) op->src[1]->ne[0];
|
const int32_t dk = (int32_t) op->src[1]->ne[0];
|
||||||
const int32_t dv = (int32_t) op->src[2]->ne[0];
|
const int32_t dv = (int32_t) op->src[2]->ne[0];
|
||||||
|
|
||||||
@@ -1575,7 +1575,7 @@ static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
cv = [[MTLFunctionConstantValues alloc] init];
|
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
|
||||||
|
|
||||||
[cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 0];
|
[cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 0];
|
||||||
[cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 1];
|
[cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT + 1];
|
||||||
@@ -1586,7 +1586,11 @@ static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
|
|||||||
[cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 21];
|
[cv setConstantValue:&ns20 type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 21];
|
||||||
[cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 22];
|
[cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT + 22];
|
||||||
|
|
||||||
return ggml_metal_compile_kernel(backend, base, name, cv);
|
res = ggml_metal_compile_kernel(backend, base, name, cv);
|
||||||
|
|
||||||
|
[cv release];
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1604,8 +1608,6 @@ static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec(
|
|||||||
char name[256];
|
char name[256];
|
||||||
|
|
||||||
@autoreleasepool {
|
@autoreleasepool {
|
||||||
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
|
|
||||||
|
|
||||||
const int32_t dk = (int32_t) op->src[1]->ne[0];
|
const int32_t dk = (int32_t) op->src[1]->ne[0];
|
||||||
const int32_t dv = (int32_t) op->src[2]->ne[0];
|
const int32_t dv = (int32_t) op->src[2]->ne[0];
|
||||||
|
|
||||||
@@ -1637,7 +1639,7 @@ static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec(
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
cv = [[MTLFunctionConstantValues alloc] init];
|
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
|
||||||
|
|
||||||
[cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 0];
|
[cv setConstantValue:&has_mask type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 0];
|
||||||
[cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 1];
|
[cv setConstantValue:&has_sinks type:MTLDataTypeBool atIndex:FC_FLASH_ATTN_EXT_VEC + 1];
|
||||||
@@ -1649,7 +1651,11 @@ static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec(
|
|||||||
[cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 22];
|
[cv setConstantValue:&nsg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 22];
|
||||||
[cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 23];
|
[cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC + 23];
|
||||||
|
|
||||||
return ggml_metal_compile_kernel(backend, base, name, cv);
|
res = ggml_metal_compile_kernel(backend, base, name, cv);
|
||||||
|
|
||||||
|
[cv release];
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1663,8 +1669,6 @@ static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec_re
|
|||||||
char name[256];
|
char name[256];
|
||||||
|
|
||||||
@autoreleasepool {
|
@autoreleasepool {
|
||||||
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
|
|
||||||
|
|
||||||
snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
|
snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
|
||||||
snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
|
snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
|
||||||
|
|
||||||
@@ -1674,12 +1678,16 @@ static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec_re
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
cv = [[MTLFunctionConstantValues alloc] init];
|
MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
|
||||||
|
|
||||||
[cv setConstantValue:&dv type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 0];
|
[cv setConstantValue:&dv type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 0];
|
||||||
[cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 1];
|
[cv setConstantValue:&nwg type:MTLDataTypeInt atIndex:FC_FLASH_ATTN_EXT_VEC_REDUCE + 1];
|
||||||
|
|
||||||
return ggml_metal_compile_kernel(backend, base, name, cv);
|
res = ggml_metal_compile_kernel(backend, base, name, cv);
|
||||||
|
|
||||||
|
[cv release];
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_UNUSED(op);
|
GGML_UNUSED(op);
|
||||||
@@ -5770,6 +5778,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
|
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
|
||||||
[cmd_buf retain];
|
[cmd_buf retain];
|
||||||
|
|
||||||
|
if (ctx->cmd_bufs[n_cb].obj) {
|
||||||
|
[ctx->cmd_bufs[n_cb].obj release];
|
||||||
|
}
|
||||||
ctx->cmd_bufs[n_cb].obj = cmd_buf;
|
ctx->cmd_bufs[n_cb].obj = cmd_buf;
|
||||||
|
|
||||||
[cmd_buf enqueue];
|
[cmd_buf enqueue];
|
||||||
|
|||||||
Reference in New Issue
Block a user