mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-01 09:01:57 +00:00
metal : fix check for bfloat tensor support
This commit is contained in:
@@ -96,6 +96,8 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder);
|
|||||||
typedef struct ggml_metal_library * ggml_metal_library_t;
|
typedef struct ggml_metal_library * ggml_metal_library_t;
|
||||||
|
|
||||||
ggml_metal_library_t ggml_metal_library_init (ggml_metal_device_t dev);
|
ggml_metal_library_t ggml_metal_library_init (ggml_metal_device_t dev);
|
||||||
|
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose);
|
||||||
|
|
||||||
void ggml_metal_library_free(ggml_metal_library_t lib);
|
void ggml_metal_library_free(ggml_metal_library_t lib);
|
||||||
|
|
||||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
|
ggml_metal_pipeline_t ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
|
||||||
|
|||||||
@@ -303,6 +303,72 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
|
||||||
|
if (source == NULL) {
|
||||||
|
GGML_LOG_ERROR("%s: source is NULL\n", __func__);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
id<MTLDevice> device = ggml_metal_device_get_obj(dev);
|
||||||
|
id<MTLLibrary> library = nil;
|
||||||
|
NSError * error = nil;
|
||||||
|
|
||||||
|
const int64_t t_start = ggml_time_us();
|
||||||
|
|
||||||
|
NSString * src = [[NSString alloc] initWithBytes:source
|
||||||
|
length:strlen(source)
|
||||||
|
encoding:NSUTF8StringEncoding];
|
||||||
|
if (!src) {
|
||||||
|
GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
@autoreleasepool {
|
||||||
|
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
|
||||||
|
|
||||||
|
MTLCompileOptions * options = [MTLCompileOptions new];
|
||||||
|
options.preprocessorMacros = prep;
|
||||||
|
|
||||||
|
library = [device newLibraryWithSource:src options:options error:&error];
|
||||||
|
if (error) {
|
||||||
|
if (verbose) {
|
||||||
|
GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]);
|
||||||
|
} else {
|
||||||
|
GGML_LOG_ERROR("%s: error compiling source\n", __func__);
|
||||||
|
}
|
||||||
|
library = nil;
|
||||||
|
}
|
||||||
|
|
||||||
|
[options release];
|
||||||
|
}
|
||||||
|
|
||||||
|
[src release];
|
||||||
|
|
||||||
|
if (!library) {
|
||||||
|
if (verbose) {
|
||||||
|
GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (verbose) {
|
||||||
|
GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
|
||||||
|
if (!res) {
|
||||||
|
GGML_LOG_ERROR("%s: calloc failed\n", __func__);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
res->obj = library;
|
||||||
|
res->device = device;
|
||||||
|
res->pipelines = ggml_metal_pipelines_init();
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_metal_library_free(ggml_metal_library_t lib) {
|
void ggml_metal_library_free(ggml_metal_library_t lib) {
|
||||||
if (!lib) {
|
if (!lib) {
|
||||||
return;
|
return;
|
||||||
@@ -474,12 +540,56 @@ ggml_metal_device_t ggml_metal_device_init(void) {
|
|||||||
|
|
||||||
dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
|
||||||
dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
||||||
|
if (getenv("GGML_METAL_BF16_DISABLE") != NULL) {
|
||||||
|
dev->props.has_bfloat = false;
|
||||||
|
}
|
||||||
|
|
||||||
dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
|
dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
|
||||||
if (getenv("GGML_METAL_TENSOR_DISABLE") != NULL) {
|
if (getenv("GGML_METAL_TENSOR_DISABLE") != NULL) {
|
||||||
dev->props.has_tensor = false;
|
dev->props.has_tensor = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// try to compile a dummy tensor kernel to determine if the tensor API is supported for bfloat
|
||||||
|
if (dev->props.has_tensor && dev->props.has_bfloat) {
|
||||||
|
const char * src_tensor_bf16 = "\n"
|
||||||
|
"#include <metal_stdlib> \n"
|
||||||
|
"#include <metal_tensor> \n"
|
||||||
|
"#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
|
||||||
|
" \n"
|
||||||
|
"using namespace metal; \n"
|
||||||
|
"using namespace mpp::tensor_ops; \n"
|
||||||
|
" \n"
|
||||||
|
"kernel void bfloat_dummy_kernel( \n"
|
||||||
|
" tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
|
||||||
|
" tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
|
||||||
|
" uint2 tgid [[threadgroup_position_in_grid]]) \n"
|
||||||
|
"{ \n"
|
||||||
|
" // Create slices for this threadgroup (no real computation performed). \n"
|
||||||
|
" auto tA = A.slice(0, (int)tgid.y); \n"
|
||||||
|
" auto tB = B.slice((int)tgid.x, 0); \n"
|
||||||
|
" \n"
|
||||||
|
" // Minimal matmul descriptor: 8×8 tile with dynamic K dimension. \n"
|
||||||
|
" matmul2d< \n"
|
||||||
|
" matmul2d_descriptor(8, 8, dynamic_extent), \n"
|
||||||
|
" execution_thread> mm; \n"
|
||||||
|
" \n"
|
||||||
|
" // Obtain a cooperative destination tensor of bfloat type. \n"
|
||||||
|
" auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), bfloat>(); \n"
|
||||||
|
" \n"
|
||||||
|
" // Silence “unused variable” warnings. \n"
|
||||||
|
" (void)cT; \n"
|
||||||
|
"}";
|
||||||
|
|
||||||
|
GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__);
|
||||||
|
ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_bf16, false);
|
||||||
|
if (lib == NULL) {
|
||||||
|
GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
|
||||||
|
dev->props.has_bfloat = false;
|
||||||
|
} else {
|
||||||
|
ggml_metal_library_free(lib);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
dev->props.use_residency_sets = true;
|
dev->props.use_residency_sets = true;
|
||||||
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
|
||||||
dev->props.use_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
|
dev->props.use_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
|
||||||
|
|||||||
Reference in New Issue
Block a user