	cuda : loading models directly into VRAM, norm calculation on GPU, broadcasting for ggml_mul (#1483)
* Broadcasting for ggml_mul
* CUDA kernel for ggml_mul, norms in VRAM
* GPU weights not in RAM, direct loading with cuFile
* fixup! GPU weights not in RAM, direct loading with cuFile
* fixup! GPU weights not in RAM, direct loading with cuFile
* define default model path once, sync path with readme (#1366)
* ~7% faster Q5_1 AVX2 code (#1477)
* convert.py: Support models which are stored in a single pytorch_model.bin (#1469)
* Support models in a single pytorch_model.bin
* Remove spurious line with typo
* benchmark-matmul: Print the average of the test results (#1490)
* Remove unused n_parts parameter (#1509)
* Fixes #1511 lambda issue for w64devkit (mingw) (#1513)
* Fix for w64devkit and mingw
* make kv_f16 the default for api users (#1517)
* minor : fix compile warnings
* readme : adds WizardLM to the list of supported models (#1485)
* main : make reverse prompt option act as a stop token in non-interactive mode (#1032)
* Make reverse prompt option act as a stop token in non-interactive scenarios
* Making requested review changes
* Update gpt_params_parse and fix a merge error
* Revert "Update gpt_params_parse and fix a merge error"
  This reverts commit 2bb2ff1748.
* Update gpt_params_parse and fix a merge error take 2
* examples : add persistent chat (#1495)
* examples : add persistent chat
* examples : fix whitespace
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* tests : add missing header
* ggml : use F16 instead of F32 in Q4_0, Q4_1, Q8_0 (#1508)
* ggml : use F16 instead of F32 in Q4_0, Q4_1 and Q8_0
* llama : bump LLAMA_FILE_VERSION to 3
* cuda : update Q4 and Q8 dequantize kernels
* ggml : fix AVX dot products
* readme : update performance table + hot topics
* ggml : fix scalar implementation of Q4_1 dot
* llama : fix compile warnings in llama_set_state_data()
* llama : fix name shadowing and C4146 (#1526)
* Fix name shadowing and C4146
* Fix if macros not using defined when required
* Update llama-util.h
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Update llama-util.h
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Code style
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---------
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Fix for mingw (#1462)
* llama : add llama_init_backend() API (close #1527)
* feature : add blis and other BLAS implementation support (#1502)
* feature: add blis support
* feature: allow all BLA_VENDOR to be assigned in cmake arguments. align with whisper.cpp pr 927
* fix: version detection for BLA_SIZEOF_INTEGER, recover min version of cmake
* Fix typo in INTEGER
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Revert "feature : add blis and other BLAS implementation support (#1502)"
  This reverts commit 07e9ace0f9.
* GPU weights not in RAM, direct loading with cuFile
* llama : code style fixes + progress print fix
* ggml : ggml_mul better broadcast support
* cmake : workarounds for cufile when CMake version < 3.25
* gg rebase fixup
* Loop in llama.cpp, fixed progress callback
* Attempt clang-tidy fix
* llama : fix vram size computation
* Add forgotten fclose()
---------
Co-authored-by: András Salamon <ott2@users.noreply.github.com>
Co-authored-by: Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com>
Co-authored-by: Tom Jobbins <784313+TheBloke@users.noreply.github.com>
Co-authored-by: rankaiyx <rankaiyx@rankaiyx.com>
Co-authored-by: Stephan Walter <stephan@walter.name>
Co-authored-by: DannyDaemonic <DannyDaemonic@gmail.com>
Co-authored-by: Erik Scholz <Green-Sky@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: David Kennedy <dakennedyd@gmail.com>
Co-authored-by: Jason McCartney <jmac@theroot.org>
Co-authored-by: Evan Jones <evan.q.jones@gmail.com>
Co-authored-by: Maxime <672982+maximegmd@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zenix <zenixls2@gmail.com>
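
The user-visible API change bundled in here is that ggml_mul now broadcasts its second argument across the rows of its first, which is what lets llama.cpp drop the ggml_repeat around the RMS-norm weights (see the llama.cpp hunks below). A minimal sketch of the new call pattern, written against the ggml C API as it existed around this commit; the toy sizes, the single-threaded graph compute and the CPU-only build are assumptions of the sketch, not part of the commit:

// Minimal sketch (not from the commit): multiply a 2D activation tensor by a
// 1D weight vector using the broadcasting ggml_mul introduced here.
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,  // arena for tensors + graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int n_embd   = 8;  // toy sizes
    const int n_tokens = 4;

    // a: [n_embd, n_tokens] activations, b: [n_embd] norm weights
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
    ggml_set_f32(a, 2.0f);
    ggml_set_f32(b, 0.5f);

    // Before this commit: ggml_mul(ctx, ggml_repeat(ctx, b, a), a).
    // Now b is broadcast row-wise as long as b->ne[0] == a->ne[0].
    struct ggml_tensor * c = ggml_mul(ctx, a, b);

    struct ggml_cgraph gf = ggml_build_forward(c);
    gf.n_threads = 1;
    ggml_graph_compute(ctx, &gf);

    printf("c[0] = %f\n", ggml_get_f32_1d(c, 0));  // expect 1.0

    ggml_free(ctx);
    return 0;
}
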

ggml-cuda.cu (123 changed lines)

@@ -83,9 +83,19 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+#define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 #define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
 
+static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= kx) {
+        return;
+    }
+    dst[i] = x[i] * y[i%ky];
+}
+
 static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
@@ -228,6 +238,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     }
 }
 
+static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
+    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -467,6 +482,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
     }
 }
 
+static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[2];
+    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+    size_t x_size, d_size;
+
+    float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
+    float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
+    float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            const int i0 = i03*ne02 + i02;
+            float * c_X2 = d_X + i0*ne01*ne00;
+            float * c_D2 = d_D + i0*ne01*ne00;
+
+            cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaEvent_t  cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
+
+            // copy src0 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
+            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+            // wait for data
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                const int64_t i13 = i03%ne13;
+                const int64_t i12 = i02%ne12;
+                const int64_t i11 = i01%ne11;
+                const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
+
+                float * c_X1 = c_X2 + i01*ne00;
+                float * c_Y = d_Y + i1*ne10;
+                float * c_D1 = c_D2 + i01*ne00;
+
+                // compute
+                mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_D, d_size);
+}
+
 static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -724,6 +800,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     ggml_cuda_pool_free(d_Q, q_size);
 }
 
+void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_mul_f32(src0, src1, dst);
+}
+
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
     const int64_t ne10 = src1->ne[0];
 
@@ -797,14 +878,48 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
 
     size_t q_size;
-    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+    char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
 
     cudaStream_t cudaStream2 = g_cudaStreams2[0];
 
     // copy tensor to device
-    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
-    CUDA_CHECK(cudaDeviceSynchronize());
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            int i = i3*ne2 + i2;
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
+        }
+    }
 
-    tensor->data = d_Q;
+    tensor->data = dst;
     tensor->backend = GGML_BACKEND_CUDA;
 }
+
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
+    FILE * fp = fopen(fname, "rb");
+
+    const size_t size = ggml_nbytes(tensor);
+
+    void * buf;
+    CUDA_CHECK(cudaMalloc(&buf, size));
+    void * buf_host = malloc(size);
+
+#ifdef _WIN32
+    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
+#else
+    int ret = fseek(fp, (long) offset, SEEK_SET);
+#endif
+    GGML_ASSERT(ret == 0); // same
+
+    size_t ret2 = fread(buf_host, size, 1, fp);
+    if (ret2 != 1) {
+        fprintf(stderr, "unexpectedly reached end of file");
+        exit(1);
+    }
+
+    cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+    cudaDeviceSynchronize();
+
+    tensor->data = buf;
+    free(buf_host);
+    fclose(fp);
+}

ggml-cuda.h

@@ -6,6 +6,7 @@ extern "C" {
 
 void   ggml_init_cublas(void);
 
+void   ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
@@ -15,6 +16,7 @@ void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
 void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
 
 #ifdef  __cplusplus
 }
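
The two new entry points in ggml-cuda.h are meant to be driven by the model loader: a tensor that should live in VRAM gets its backend field set to GGML_BACKEND_CUDA and its bytes are read straight from the model file at a known offset. A hedged sketch of that call pattern follows; the helper name, file name and offset are hypothetical, and the loader changes in llama.cpp below are the authoritative usage:

// Hypothetical loader fragment (not from this commit) showing the intended
// use of the new ggml-cuda entry points declared above.
#include "ggml.h"
#include "ggml-cuda.h"

// `t` was created with ggml_new_tensor_*; `file_off` is where its data
// starts inside `fname` (both are assumptions of this sketch).
static void offload_tensor(const char * fname, struct ggml_tensor * t, size_t file_off) {
    t->backend = GGML_BACKEND_CUDA;          // mark it as device-resident
    ggml_cuda_load_data(fname, t, file_off); // fread + cudaMemcpy into a cudaMalloc'd buffer
    // t->data now points to device memory; CPU code must not dereference it.
}
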

ggml.c (90 changed lines)

@@ -3776,6 +3776,12 @@ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct g
         (t1->ne[3]%t0->ne[3] == 0);
 }
 
+static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1);
+}
+
 static inline int ggml_up32(int n) {
     return (n + 31) & ~31;
 }
@@ -4658,11 +4664,15 @@ struct ggml_tensor * ggml_mul_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    // TODO: support less-strict constraint
+    //       GGML_ASSERT(ggml_can_repeat(b, a));
+    GGML_ASSERT(ggml_can_repeat_rows(b, a));
 
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
         is_node = true;
     }
 
@@ -7960,7 +7970,7 @@ static void ggml_compute_forward_mul_f32(
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -7968,10 +7978,25 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nr  = ggml_nrows(src0);
-    const int64_t ne0 = src0->ne[0];
-    const int64_t ne1 = src0->ne[1];
-    const int64_t ne2 = src0->ne[2];
-
+#ifdef GGML_USE_CUBLAS
+    if (src1->backend == GGML_BACKEND_CUDA) {
+        if (ith == 0) {
+            ggml_cuda_mul(src0, src1, dst);
+        }
+        return;
+    }
+#endif
+
+    const int64_t nr = ggml_nrows(src0);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
     const size_t nb00 = src0->nb[0];
     const size_t nb01 = src0->nb[1];
@@ -7990,44 +8015,51 @@ static void ggml_compute_forward_mul_f32(
 
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(ne00 == ne10);
 
     if (nb10 == sizeof(float)) {
-        for (int ir = ith; ir < nr; ir += nth) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
 #ifdef GGML_USE_ACCELERATE
             UNUSED(ggml_vec_mul_f32);
 
-            vDSP_vmul(
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+            vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr,  1, ne00);
 #else
-            ggml_vec_mul_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+            ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
 #endif
                 // }
             // }
         }
     } else {
         // src1 is not contiguous
-        for (int ir = ith; ir < nr; ir += nth) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            // src0 and dst are same shape => same indices
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
 
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+            for (int64_t i0 = 0; i0 < ne00; i0++) {
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr);
             }
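
To make the CPU-side broadcasting rule concrete: for each row (i01, i02, i03) of src0, the matching src1 row is (i01 % ne11, i02 % ne12, i03 % ne13), and the two rows are multiplied element-wise (the new assert requires ne00 == ne10). The following is a simplified, contiguous-only reference of that rule for illustration; it ignores the nb* strides that the real ggml_compute_forward_mul_f32 honours and is not code from the commit:

// Simplified reference of the broadcast rule used by ggml_compute_forward_mul_f32
// after this change. Assumes contiguous float tensors; strides are ignored.
#include <stdint.h>

static void mul_broadcast_rows_ref(
        const float * src0, const int64_t ne0[4],   // full-size operand
        const float * src1, const int64_t ne1[4],   // row-broadcastable operand, ne1[0] == ne0[0]
        float * dst) {                              // same shape as src0
    for (int64_t i3 = 0; i3 < ne0[3]; i3++) {
        for (int64_t i2 = 0; i2 < ne0[2]; i2++) {
            for (int64_t i1 = 0; i1 < ne0[1]; i1++) {
                const float * x = src0 + ((i3*ne0[2] + i2)*ne0[1] + i1)*ne0[0];
                // src1 row is picked by taking each index modulo src1's extent
                const float * y = src1 + (((i3 % ne1[3])*ne1[2] + (i2 % ne1[2]))*ne1[1] + (i1 % ne1[1]))*ne1[0];
                float       * d = dst  + ((i3*ne0[2] + i2)*ne0[1] + i1)*ne0[0];
                for (int64_t i0 = 0; i0 < ne0[0]; i0++) {
                    d[i0] = x[i0] * y[i0];
                }
            }
        }
    }
}
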
llama-util.h

@@ -172,7 +172,7 @@ struct llama_mmap {
 #ifdef _POSIX_MAPPED_FILES
     static constexpr bool SUPPORTED = true;
 
-    llama_mmap(struct llama_file * file, bool prefetch = true) {
+    llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */) {
         size = file->size;
         int fd = fileno(file->fp);
         int flags = MAP_SHARED;
@@ -184,9 +184,9 @@ struct llama_mmap {
             throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
         }
 
-        if (prefetch) {
+        if (prefetch > 0) {
             // Advise the kernel to preload the mapped memory
-            if (madvise(addr, file->size, MADV_WILLNEED)) {
+            if (madvise(addr, std::min(file->size, prefetch), MADV_WILLNEED)) {
                 fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
                         strerror(errno));
             }
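
The llama_mmap change above turns the prefetch flag into a byte count so the loader can ask the kernel to page in only the CPU-resident part of the file. A plain POSIX sketch of the same pattern, independent of llama.cpp and with error handling trimmed to the essentials:

// Map a file and ask the kernel to read ahead only the first `prefetch` bytes.
#include <stdio.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>

static void * map_with_prefetch(const char * path, size_t prefetch, size_t * out_size) {
    int fd = open(path, O_RDONLY);
    if (fd < 0) { return NULL; }

    struct stat st;
    if (fstat(fd, &st) != 0) { close(fd); return NULL; }
    *out_size = (size_t) st.st_size;

    void * addr = mmap(NULL, *out_size, PROT_READ, MAP_SHARED, fd, 0);
    close(fd); // the mapping keeps the file referenced
    if (addr == MAP_FAILED) { return NULL; }

    if (prefetch > 0) {
        size_t n = prefetch < *out_size ? prefetch : *out_size;
        // Best effort: advise the kernel to preload only the prefix we will read.
        if (madvise(addr, n, MADV_WILLNEED) != 0) {
            perror("madvise(MADV_WILLNEED)");
        }
    }
    return addr;
}
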

llama.cpp (199 changed lines)

@@ -1,6 +1,7 @@
 // Defines fileno on msys:
 #ifndef _GNU_SOURCE
 #define _GNU_SOURCE
+#include <cstddef>
 #include <cstdint>
 #include <cstdio>
 #endif
@@ -645,7 +646,7 @@ struct llama_model_loader {
         }
     }
 
-    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne) {
+    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne, ggml_backend backend) {
         auto it = tensors_map.name_to_idx.find(name);
         if (it == tensors_map.name_to_idx.end()) {
             throw format("llama.cpp: tensor '%s' is missing from model", name.c_str());
@@ -656,10 +657,10 @@ struct llama_model_loader {
                          name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str());
         }
 
-        return get_tensor_for(lt);
+        return get_tensor_for(lt, backend);
     }
 
-    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt) {
+    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, ggml_backend backend) {
         struct ggml_tensor * tensor;
         if (lt.ne.size() == 2) {
             tensor = ggml_new_tensor_2d(ggml_ctx, lt.type, lt.ne.at(0), lt.ne.at(1));
@@ -669,6 +670,7 @@ struct llama_model_loader {
         }
         ggml_set_name(tensor, lt.name.c_str());
         LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
+        tensor->backend = backend;
         lt.ggml_tensor = tensor;
         num_ggml_tensors_created++;
         return tensor;
@@ -682,12 +684,16 @@ struct llama_model_loader {
 
     void load_all_data(llama_progress_callback progress_callback, void *  progress_callback_user_data, llama_mlock * lmlock) {
         size_t data_size = 0;
+        size_t prefetch_size = 0;
         for (const llama_load_tensor & lt : tensors_map.tensors) {
             data_size += lt.size;
+            if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) {
+                prefetch_size += lt.size;
+            }
         }
 
         if (use_mmap) {
-            mapping.reset(new llama_mmap(&file_loaders.at(0)->file));
+            mapping.reset(new llama_mmap(&file_loaders.at(0)->file, prefetch_size));
             if (!lmlock) {
                 // Don't call the callback since the actual loading will be lazy
                 // and we can't measure it.
@@ -700,6 +706,9 @@ struct llama_model_loader {
 
         size_t done_size = 0;
         for (llama_load_tensor & lt : tensors_map.tensors) {
+            if (lt.ggml_tensor->backend != GGML_BACKEND_CPU) {
+                continue;
+            }
             if (progress_callback) {
                 progress_callback((float) done_size / data_size, progress_callback_user_data);
             }
@@ -712,9 +721,6 @@ struct llama_model_loader {
                 lmlock->grow_to(done_size);
             }
         }
-        if (progress_callback) {
-            progress_callback(1.0f, progress_callback_user_data);
-        }
     }
 
     void load_data_for(llama_load_tensor & lt) {
@@ -969,27 +975,7 @@ static void llama_model_load_internal(
     size_t ctx_size;
     size_t mmapped_size;
     ml->calc_sizes(&ctx_size, &mmapped_size);
-    fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/1024.0/1024.0);
+    fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0);
 
-    // print memory requirements
-    {
-        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
-
-        // this is the total memory required to run the inference
-        const size_t mem_required =
-            ctx_size +
-            mmapped_size +
-            MEM_REQ_SCRATCH0().at(model.type) +
-            MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at(model.type);
-
-        // this is the memory required by one llama_state
-        const size_t mem_required_state =
-            scale*MEM_REQ_KV_SELF().at(model.type);
-
-        fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
-                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
-    }
-
     // create the ggml context
     {
@@ -1011,7 +997,14 @@ static void llama_model_load_internal(
         }
     }
 
+#ifdef GGML_USE_CUBLAS
+#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CUDA
+#else
+#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU
+#endif
+
     // prepare memory for the weights
+    size_t vram_total = 0;
     {
         const uint32_t n_embd  = hparams.n_embd;
         const uint32_t n_layer = hparams.n_layer;
@@ -1019,33 +1012,87 @@ static void llama_model_load_internal(
 
         ml->ggml_ctx = ctx;
 
-        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab});
-        model.norm           = ml->get_tensor("norm.weight",           {n_embd});
-        model.output         = ml->get_tensor("output.weight",         {n_embd, n_vocab});
+        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);
+        model.norm           = ml->get_tensor("norm.weight",           {n_embd},          GGML_BACKEND_CPU);
+
+        // "output" tensor
+        {
+            ggml_backend backend_output;
+            if (n_gpu_layers > int(n_layer)) { // NOLINT
+                backend_output = LLAMA_BACKEND_OFFLOAD;
+            } else {
+                backend_output = GGML_BACKEND_CPU;
+            }
+
+            model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output);
+        }
+
+        const int i_gpu_start = n_layer - n_gpu_layers;
 
         model.layers.resize(n_layer);
         for (uint32_t i = 0; i < n_layer; ++i) {
+            const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+
             auto & layer = model.layers[i];
 
             std::string layers_i = "layers." + std::to_string(i);
 
-            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd});
+            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend);
 
-            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd});
-            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd});
-            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd});
-            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd});
+            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend);
+            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend);
+            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend);
+            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend);
 
-            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd});
+            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
 
-            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff});
-            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff,   n_embd});
-            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff});
+            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff},   backend);
+            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff,   n_embd}, backend);
+            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff},   backend);
+
+            if (backend == GGML_BACKEND_CUDA) {
+                vram_total +=
+                    ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk)             +
+                    ggml_nbytes(layer.wv)             + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) +
+                    ggml_nbytes(layer.w1)             + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
+            }
         }
     }
 
     ml->done_getting_tensors();
 
+    // print memory requirements
+    {
+        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
+
+        // this is the total memory required to run the inference
+        const size_t mem_required =
+            ctx_size +
+            mmapped_size - vram_total + // weights in VRAM not in memory
+            MEM_REQ_SCRATCH0().at(model.type) +
+            MEM_REQ_SCRATCH1().at(model.type) +
+            MEM_REQ_EVAL().at(model.type);
+
+        // this is the memory required by one llama_state
+        const size_t mem_required_state =
+            scale*MEM_REQ_KV_SELF().at(model.type);
+
+        fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
+                mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
+
+#ifdef GGML_USE_CUBLAS
+        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
+
+        fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
+        if (n_gpu_layers > (int) hparams.n_layer) {
+            fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
+        }
+        fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+#else
+        (void) n_gpu_layers;
+#endif
+    }
+
     // populate `tensors_by_name`
     for (llama_load_tensor & lt : ml->tensors_map.tensors) {
         model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor);
@@ -1053,36 +1100,34 @@ static void llama_model_load_internal(
 
     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
 
-    model.mapping = std::move(ml->mapping);
 #ifdef GGML_USE_CUBLAS
     {
-        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
-
-        fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
-
-        size_t vram_total = 0;
-
-        for (int i = 0; i < n_gpu; ++i) {
-            const auto & layer = model.layers[i];
-
-            ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq);
-            ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk);
-            ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv);
-            ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo);
-            ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1);
-            ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2);
-            ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3);
+        size_t done_size = 0;
+        size_t data_size = 0;
+        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
+            data_size += lt.size;
+            if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) {
+                done_size += lt.size;
+            }
         }
-        if (n_gpu_layers > (int) hparams.n_layer) {
-            fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
-            ggml_cuda_transform_tensor(model.output); vram_total += ggml_nbytes(model.output);
+        for (llama_load_tensor & lt : ml->tensors_map.tensors) {
+            if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
+                continue;
+            }
+            if (progress_callback) {
+                progress_callback((float) done_size / data_size, progress_callback_user_data);
+            }
+            ggml_cuda_load_data(fname.c_str(), lt.ggml_tensor, lt.shards.at(0).file_off);
+            done_size += lt.size;
         }
 
-        fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
     }
-#else
-    (void) n_gpu_layers;
-#endif
+#endif // GGML_USE_CUBLAS
+
+    if (progress_callback) {
+        progress_callback(1.0f, progress_callback_user_data);
+    }
+
+    model.mapping = std::move(ml->mapping);
 
     // loading time will be recalculate after the first eval, so
     // we take page faults deferred by mmap() into consideration
@@ -1181,10 +1226,8 @@ static bool llama_eval_internal(
         {
             cur = ggml_rms_norm(ctx0, inpL);
 
-            // cur = attention_norm*cur
-            cur = ggml_mul(ctx0,
-                        ggml_repeat(ctx0, model.layers[il].attention_norm, cur),
-                        cur);
+            // cur = cur*attention_norm(broadcasted)
+            cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
         }
 
         // self-attention
@@ -1291,10 +1334,8 @@ static bool llama_eval_internal(
             {
                 cur = ggml_rms_norm(ctx0, inpFF);
 
-                // cur = ffn_norm*cur
-                cur = ggml_mul(ctx0,
-                        ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
-                        cur);
+                // cur = cur*ffn_norm(broadcasted)
+                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
             }
 
             struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
@@ -1331,10 +1372,8 @@ static bool llama_eval_internal(
 
         inpL = ggml_rms_norm(ctx0, inpL);
 
-        // inpL = norm*inpL
-        inpL = ggml_mul(ctx0,
-                    ggml_repeat(ctx0, model.norm, inpL),
-                    inpL);
+        // inpL = inpL*norm(broadcasted)
+        inpL = ggml_mul(ctx0, inpL, model.norm);
 
         embeddings = inpL;
     }
@@ -2158,7 +2197,7 @@ struct llama_context * llama_init_from_file(
             unsigned * cur_percentage_p = (unsigned *) ctx;
             unsigned percentage = (unsigned) (100 * progress);
             while (percentage > *cur_percentage_p) {
-                ++*cur_percentage_p;
+                *cur_percentage_p = percentage;
                 fprintf(stderr, ".");
                 fflush(stderr);
                 if (percentage >= 100) {
@@ -2315,7 +2354,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
 
         // maybe this should in llama_model_loader
         if (model_loader->use_mmap) {
-            model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ false));
+            model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ 0));
         }
     }
 
@@ -2408,7 +2447,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
                 }
                 size_t idx = model_loader->tensors_map.name_to_idx[base_name];
                 llama_load_tensor & lt = model_loader->tensors_map.tensors[idx];
-                base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] });
+                base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU);
                 lt.data = (uint8_t *) lt.ggml_tensor->data;
                 model_loader->load_data_for(lt);
                 lt.ggml_tensor->data = lt.data;
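
The placement policy added to llama_model_load_internal is simple arithmetic: with i_gpu_start = n_layer - n_gpu_layers, layers at index i >= i_gpu_start are created with the CUDA backend, and the output matrix is offloaded only when n_gpu_layers exceeds n_layer. A standalone sketch of that rule; the names and the 7B-style layer count are illustrative, not taken from the loader:

// Standalone illustration of the layer-placement rule introduced above.
// The backend values mirror ggml's GGML_BACKEND_CPU / GGML_BACKEND_CUDA.
#include <stdio.h>

enum backend { BACKEND_CPU, BACKEND_CUDA };

static enum backend layer_backend(int i, int n_layer, int n_gpu_layers) {
    const int i_gpu_start = n_layer - n_gpu_layers;   // may be negative
    return i < i_gpu_start ? BACKEND_CPU : BACKEND_CUDA;
}

int main(void) {
    const int n_layer      = 32;  // e.g. a 7B-class model
    const int n_gpu_layers = 20;  // user-requested offload count

    int n_gpu = 0;
    for (int i = 0; i < n_layer; i++) {
        n_gpu += layer_backend(i, n_layer, n_gpu_layers) == BACKEND_CUDA;
    }
    // layers 12..31 are offloaded; the output tensor only when n_gpu_layers > n_layer
    printf("offloading %d of %d layers, output on GPU: %s\n",
           n_gpu, n_layer, n_gpu_layers > n_layer ? "yes" : "no");
    return 0;
}
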
Author: Johannes Gäßler