mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-31 08:51:55 +00:00
metal : fix check for bfloat tensor support
This commit is contained in:
@@ -95,7 +95,9 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder);
|
||||
|
||||
typedef struct ggml_metal_library * ggml_metal_library_t;
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev);
|
||||
ggml_metal_library_t ggml_metal_library_init (ggml_metal_device_t dev);
|
||||
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose);
|
||||
|
||||
void ggml_metal_library_free(ggml_metal_library_t lib);
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
|
||||
|
||||
@@ -303,6 +303,72 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
|
||||
if (source == NULL) {
|
||||
GGML_LOG_ERROR("%s: source is NULL\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
|
||||
id<MTLLibrary> library = nil;
|
||||
NSError * error = nil;
|
||||
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
NSString * src = [[NSString alloc] initWithBytes:source
|
||||
length:strlen(source)
|
||||
encoding:NSUTF8StringEncoding];
|
||||
if (!src) {
|
||||
GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
||||
|
||||
MTLCompileOptions * options = [MTLCompileOptions new];
|
||||
options.preprocessorMacros = prep;
|
||||
|
||||
library = [device newLibraryWithSource:src options:options error:&error];
|
||||
if (error) {
|
||||
if (verbose) {
|
||||
GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]);
|
||||
} else {
|
||||
GGML_LOG_ERROR("%s: error compiling source\n", __func__);
|
||||
}
|
||||
library = nil;
|
||||
}
|
||||
|
||||
[options release];
|
||||
}
|
||||
|
||||
[src release];
|
||||
|
||||
if (!library) {
|
||||
if (verbose) {
|
||||
GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__);
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (verbose) {
|
||||
GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
|
||||
}
|
||||
|
||||
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
||||
if (!res) {
|
||||
GGML_LOG_ERROR("%s: calloc failed\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
res->obj = library;
|
||||
res->device = device;
|
||||
res->pipelines = ggml_metal_pipelines_init();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void ggml_metal_library_free(ggml_metal_library_t lib) {
|
||||
if (!lib) {
|
||||
return;
|
||||
@@ -474,12 +540,56 @@ ggml_metal_device_t ggml_metal_device_init(void) {
|
||||
|
||||
dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||
dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
||||
if (getenv("GGML_METAL_BF16_DISABLE") != NULL) {
|
||||
dev->props.has_bfloat = false;
|
||||
}
|
||||
|
||||
dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
|
||||
if (getenv("GGML_METAL_TENSOR_DISABLE") != NULL) {
|
||||
dev->props.has_tensor = false;
|
||||
}
|
||||
|
||||
// try to compile a dummy tensor kernel to determine if the tensor API is supported for bfloat
|
||||
if (dev->props.has_tensor && dev->props.has_bfloat) {
|
||||
const char * src_tensor_bf16 = "\n"
|
||||
"#include <metal_stdlib> \n"
|
||||
"#include <metal_tensor> \n"
|
||||
"#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
|
||||
" \n"
|
||||
"using namespace metal; \n"
|
||||
"using namespace mpp::tensor_ops; \n"
|
||||
" \n"
|
||||
"kernel void bfloat_dummy_kernel( \n"
|
||||
" tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
|
||||
" tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
|
||||
" uint2 tgid [[threadgroup_position_in_grid]]) \n"
|
||||
"{ \n"
|
||||
" // Create slices for this threadgroup (no real computation performed). \n"
|
||||
" auto tA = A.slice(0, (int)tgid.y); \n"
|
||||
" auto tB = B.slice((int)tgid.x, 0); \n"
|
||||
" \n"
|
||||
" // Minimal matmul descriptor: 8×8 tile with dynamic K dimension. \n"
|
||||
" matmul2d< \n"
|
||||
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
|
||||
" execution_thread> mm; \n"
|
||||
" \n"
|
||||
" // Obtain a cooperative destination tensor of bfloat type. \n"
|
||||
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), bfloat>(); \n"
|
||||
" \n"
|
||||
" // Silence “unused variable” warnings. \n"
|
||||
" (void)cT; \n"
|
||||
"}";
|
||||
|
||||
GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__);
|
||||
ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_bf16, false);
|
||||
if (lib == NULL) {
|
||||
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
|
||||
dev->props.has_bfloat = false;
|
||||
} else {
|
||||
ggml_metal_library_free(lib);
|
||||
}
|
||||
}
|
||||
|
||||
dev->props.use_residency_sets = true;
|
||||
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
||||
dev->props.use_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
|
||||
|
||||
Reference in New Issue
Block a user