ggml : adjust mul_mat_f16 work memory (#1226)
* llama : minor - remove explicit int64_t cast
* ggml : reduce memory buffer for F16 mul_mat when not using cuBLAS
* ggml : add asserts to guard for incorrect wsize
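The core of the commit is the work-memory planning change in `ggml_graph_compute` (last ggml.c hunk below). A minimal sketch of the new sizing rule, using simplified stand-ins for the real ggml tensor structs (the constants and names here are illustrative; only the arithmetic mirrors the diff):

```c
#include <stdint.h>
#include <stddef.h>

// Illustrative sizes; in ggml these come from GGML_TYPE_SIZE[GGML_TYPE_F16/F32].
enum { SIZE_F16 = 2, SIZE_F32 = 4 };

// ne[0..3] holds the number of elements per dimension, as in ggml tensors.
static size_t mul_mat_f16_work_size(const int64_t src0_ne[4],
                                    const int64_t src1_ne[4],
                                    int use_cublas) {
    if (use_cublas) {
        // cuBLAS path: the full 3D/4D src1 is converted to F16 in the work buffer
        return SIZE_F16 * (size_t)(src1_ne[0]*src1_ne[1]*src1_ne[2]*src1_ne[3]);
    }
    // CPU/BLAS path: only a single 2D matrix from src0 is converted to F32 at a time
    return SIZE_F32 * (size_t)(src0_ne[0]*src0_ne[1]);
}
```

Previously both paths reserved `GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(nelements(src1), nelements(src0))`, so the new rule shrinks the work buffer in both cases.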
Makefile (9 changed lines)
@@ -34,10 +34,15 @@ endif
 #
 
 # keep standard at C11 and C++11
-CFLAGS   = -I.              -O3 -DNDEBUG -std=c11   -fPIC
-CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
+CFLAGS   = -I.              -O3 -std=c11   -fPIC
+CXXFLAGS = -I. -I./examples -O3 -std=c++11 -fPIC
 LDFLAGS  =
 
+ifndef LLAMA_DEBUG
+	CFLAGS   += -DNDEBUG
+	CXXFLAGS += -DNDEBUG
+endif
+
 # warnings
 CFLAGS   += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith
 CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar
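One practical consequence of the Makefile change: `-DNDEBUG` is now added only for regular builds, so building with the variable defined, e.g. `make LLAMA_DEBUG=1`, leaves `NDEBUG` out and the new `wsize` guards in ggml.c below actually run. A tiny illustration of that interaction (the helper is hypothetical, not part of the tree):

```c
#include <assert.h>
#include <stddef.h>

// assert() compiles to a no-op when NDEBUG is defined (the default build);
// a LLAMA_DEBUG build leaves NDEBUG undefined, so the check really executes.
static void guard_wsize(size_t bytes_written, size_t wsize) {
    assert(bytes_written <= wsize);
}

int main(void) {
    guard_wsize(64, 128); // passes either way; would abort only without NDEBUG
    return 0;
}
```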
							
								
								
									
ggml.c (21 changed lines)
@@ -8245,8 +8245,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         ggml_fp16_t * d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
         ggml_fp16_t * d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
         float       * d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
-#else
-        float * const wdata = params->wdata;
 #endif
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -8263,8 +8261,11 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                             wdata[id++] = GGML_FP32_TO_FP16(*(float *) ((char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11 + i00*nb10));
                         }
                     }
+
+                    assert(id*sizeof(ggml_fp16_t) <= params->wsize);
                 }
 #else
+                float * const wdata = params->wdata;
                 {
                     size_t id = 0;
                     for (int64_t i01 = 0; i01 < ne01; ++i01) {
@@ -8272,6 +8273,8 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                             wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
                         }
                     }
+
+                    assert(id*sizeof(float) <= params->wsize);
                 }
 #endif
 
@@ -8537,7 +8540,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
                         dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
                         id += ne00;
                     }
+
+                    assert(id*sizeof(float) <= params->wsize);
                 }
+
                 const float * x = wdata;
 #endif
 
@@ -11571,10 +11577,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
-                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*MAX(ggml_nelements(node->src1), ggml_nelements(node->src0));
-                                //printf("src0: ne0 = %d, ne1 = %d, ne = %d\n", node->src0->ne[0], node->src0->ne[1], node->src0->ne[0]*node->src0->ne[1]);
-                                //printf("src1: ne0 = %d, ne1 = %d, ne = %d\n", node->src1->ne[0], node->src1->ne[1], node->src1->ne[0]*node->src1->ne[1]);
-                                //printf("cur = %zu\n", cur);
+#if defined(GGML_USE_CUBLAS)
+                                // with cuBLAS, we need memory for the full 3D / 4D data of src1
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
+#else
+                                // here we need memory just for single 2D matrix from src0
+                                cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
+#endif
                             } else {
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
                             }
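All of the asserts added above check the same invariant: the bytes written into `params->wdata` by the conversion or dequantization loops must stay within the `wsize` that `ggml_graph_compute` reserved. A reduced sketch of that pattern, with stand-ins for ggml's half type and FP32→FP16 conversion (not the actual ggml code):

```c
#include <assert.h>
#include <stdint.h>
#include <stddef.h>
#include <stdio.h>

typedef uint16_t fp16_t;                                    // stand-in for ggml_fp16_t
static fp16_t fp32_to_fp16(float x) { (void)x; return 0; }  // stand-in for GGML_FP32_TO_FP16

// Fill the work buffer with n converted values and verify the bytes written,
// mirroring assert(id*sizeof(ggml_fp16_t) <= params->wsize) from the diff.
static void fill_f16_work(fp16_t * wdata, size_t wsize, const float * src, int64_t n) {
    size_t id = 0;
    for (int64_t i = 0; i < n; ++i) {
        wdata[id++] = fp32_to_fp16(src[i]);
    }
    assert(id*sizeof(fp16_t) <= wsize);
}

int main(void) {
    float  src[8] = {0};
    fp16_t buf[8];
    fill_f16_work(buf, sizeof(buf), src, 8); // writes 16 bytes into a 16-byte buffer
    printf("ok\n");
    return 0;
}
```

With the smaller buffer from the planning change, these checks are what would catch a mismatch between the planner and the per-op loops in a LLAMA_DEBUG build.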
llama.cpp (2 changed lines)
@@ -780,7 +780,7 @@ static bool kv_cache_init(
     const int n_embd  = hparams.n_embd;
     const int n_layer = hparams.n_layer;
 
-    const int64_t n_mem      = (int64_t)n_layer*n_ctx;
+    const int64_t n_mem      = n_layer*n_ctx;
     const int64_t n_elements = n_embd*n_mem;
 
     cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
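The llama.cpp tweak is the "minor" item from the message: `n_layer` and `n_ctx` are plain `int`, so with the cast the product was computed in 64-bit, while without it the product is computed in `int` and then widened into the `int64_t`. For the layer counts and context sizes involved the product fits in `int` comfortably, which is presumably why the explicit cast was dropped. The hypothetical values below just show where such a cast would start to matter:

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // hypothetical values, not taken from llama.cpp
    const int n_layer = 32, n_ctx = 2048;

    const int64_t with_cast    = (int64_t)n_layer*n_ctx; // multiply in 64-bit
    const int64_t without_cast = n_layer*n_ctx;          // multiply in int, then widen

    // identical here (65536 fits easily in int); the int multiplication would
    // only overflow (undefined behavior) if n_layer*n_ctx exceeded INT_MAX
    printf("%lld %lld\n", (long long)with_cast, (long long)without_cast);
    return 0;
}
```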