	ggml-cuda : use graph allocator (#2684)
use a different function for no_alloc to avoid breaking backwards compat, fixes lora
remove 512 n_batch limit
fixed 2048 batch size
cleanup

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
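At a high level, this change drops the hard-coded per-model VRAM scratch tables and instead sizes the CUDA scratch buffer from what the ggml-alloc graph allocator measures, splitting offloading into two steps: tag tensors for the GPU without allocating, then bind them to the scratch buffer at the offsets the allocator chose. Below is a condensed sketch of that call sequence, based on the llama.cpp hunks in this commit; the helper name offload_graph_with_scratch and its parameters are illustrative, not part of the ggml API.

#include "ggml.h"
#include "ggml-alloc.h"
#ifdef GGML_USE_CUBLAS
#  include "ggml-cuda.h"
#endif

// Illustrative helper (not part of ggml): tensors were previously marked with
// ggml_cuda_assign_buffers_no_alloc() while the graph was built, and
// ggml_cuda_set_scratch_size() is assumed to have been called with the
// allocator's measured size. `alloc` was created over the host buffer whose
// data pointer is `buf_alloc_data`.
static void offload_graph_with_scratch(struct ggml_allocr * alloc, struct ggml_cgraph * gf, void * buf_alloc_data) {
    // let the graph allocator place every tensor inside the measured buffer
    ggml_allocr_reset(alloc);
    ggml_allocr_alloc_graph(alloc, gf);

#ifdef GGML_USE_CUBLAS
    // tensors tagged no_alloc have backend == GGML_BACKEND_GPU but no extra yet;
    // reuse the allocator's offset as the offset into the CUDA scratch buffer
    for (int i = 0; i < gf->n_nodes; i++) {
        struct ggml_tensor * node = gf->nodes[i];
        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
            ggml_cuda_assign_scratch_offset(node, (size_t) ((char *) node->data - (char *) buf_alloc_data));
        }
    }
#else
    (void) buf_alloc_data;
#endif
}

In the actual llama.cpp change the same offset translation is also applied to gf->leafs, and the size handed to ggml_cuda_set_scratch_size() is the allocator's measured alloc_size, or 0 when the low_vram option disables the scratch buffer.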
@@ -289,7 +289,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_batch = std::stoi(argv[i]);
-            params.n_batch = std::min(512, params.n_batch);
         } else if (arg == "--keep") {
             if (++i >= argc) {
                 invalid_param = true;

ggml-cuda.cu (75 lines changed)

@@ -3887,13 +3887,13 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 // rope == RoPE == rotary positional embedding
 static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0,
                                 const float p_delta, const int p_delta_rows, const float theta_scale) {
-    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
+    const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
         return;
     }

-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
     const int i = row*ncols + col;

     const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
@@ -3965,8 +3965,8 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols,
 }

 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
+    const int col = blockDim.y*blockIdx.y + threadIdx.y;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;

     if (col >= ncols) {
         return;
@@ -3982,9 +3982,9 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 // values are also not normalized to the maximum value by subtracting it in the exponential function
 // theoretically these changes could cause problems with rounding error and arithmetic overflow but for LLaMa it seems to be fine
 static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int block_size = blockDim.x;
-    const int tid = threadIdx.x;
+    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int block_size = blockDim.y;
+    const int tid = threadIdx.y;

     float tmp = 0.0;

@@ -4776,9 +4776,9 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
 static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(nrows % 2 == 0);
-    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
-    const dim3 block_nums(num_blocks_x, nrows, 1);
+    const dim3 block_nums(nrows, num_blocks_x, 1);
     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }

@@ -4800,15 +4800,15 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const
 }

 static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
-    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
+    const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
     const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, nrows_x, 1);
+    const dim3 block_nums(nrows_x, block_num_x, 1);
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }

 static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums(1, nrows_x, 1);
+    const dim3 block_dims(1, WARP_SIZE, 1);
+    const dim3 block_nums(nrows_x, 1, 1);
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }

@@ -6313,7 +6313,7 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }

-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -6322,14 +6322,19 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW || src0_op == GGML_OP_PERMUTE) {
-            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace);
+            ggml_cuda_assign_buffers_impl(tensor->src[0], scratch, force_inplace, no_alloc);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src[1]->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace);
+        ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }

     tensor->backend = GGML_BACKEND_GPU;
+
+    if (scratch && no_alloc) {
+        return;
+    }
+
     struct ggml_tensor_extra_gpu * extra;

     const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
@@ -6381,16 +6386,48 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
     tensor->extra = extra;
 }

+void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset) {
+    if (g_scratch_size == 0) {
+        return;
+    }
+    if (g_scratch_buffer == nullptr) {
+        CUDA_CHECK(cudaMalloc(&g_scratch_buffer, g_scratch_size));
+    }
+
+    struct ggml_tensor_extra_gpu * extra = ggml_cuda_alloc_temp_tensor_extra();
+
+    const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW;
+
+    if (inplace && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) {
+        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src[0]->extra;
+        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
+        size_t view_offset = 0;
+        if (tensor->op == GGML_OP_VIEW) {
+            memcpy(&view_offset, tensor->op_params, sizeof(size_t));
+        }
+        extra->data_device[g_main_device] = src0_ddc + view_offset;
+    } else {
+        extra->data_device[g_main_device] = (char *) g_scratch_buffer + offset;
+    }
+
+    tensor->extra = extra;
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true, false);
+    ggml_cuda_assign_buffers_impl(tensor, true, false, false);
+}
+
+void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true, false, true);
 }

 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, false);
+    ggml_cuda_assign_buffers_impl(tensor, false, false, false);
 }

 void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false, true);
+    ggml_cuda_assign_buffers_impl(tensor, false, true, false);
 }

 void ggml_cuda_set_main_device(int main_device) {

ggml-cuda.h

@@ -16,9 +16,14 @@ GGML_API bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const str
 GGML_API void   ggml_cuda_set_tensor_split(const float * tensor_split);
 GGML_API void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_free_data(struct ggml_tensor * tensor);
+
 GGML_API void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 GGML_API void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
+
+GGML_API void   ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
+GGML_API void   ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
+
 GGML_API void   ggml_cuda_set_main_device(int main_device);
 GGML_API void   ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 GGML_API void   ggml_cuda_set_scratch_size(size_t scratch_size);

llama.cpp (239 lines changed)

@@ -10,13 +10,7 @@

 #include "ggml.h"

-#if !defined(GGML_USE_CUBLAS)
-#  include "ggml-alloc.h"
-#  define LLAMA_USE_ALLOCATOR
-#else
-#  define LLAMA_USE_SCRATCH
-#  define LLAMA_MAX_SCRATCH_BUFFERS 16
-#endif
+#include "ggml-alloc.h"

 #ifdef GGML_USE_CUBLAS
 #  include "ggml-cuda.h"
@@ -588,14 +582,6 @@ struct llama_state {

 static llama_state g_state;

-//
-// memory sizes (calculated for n_batch == 512)
-//
-
-// computed for n_ctx == 2048
-// TODO: dynamically determine these sizes
-//       needs modifications in ggml
-
 // available llama models
 enum e_model {
     MODEL_UNKNOWN,
@@ -610,76 +596,6 @@ enum e_model {
 static const size_t kB = 1024;
 static const size_t MB = 1024*1024;

-static std::map<e_model, size_t> MEM_REQ_SCRATCH0(int n_ctx)
-{
-    std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   ((size_t) n_ctx / 16ull +  92ull) * MB },
-        { MODEL_7B,   ((size_t) n_ctx / 16ull + 100ull) * MB },
-        { MODEL_13B,  ((size_t) n_ctx / 12ull + 120ull) * MB },
-        { MODEL_30B,  ((size_t) n_ctx /  9ull + 160ull) * MB },
-        { MODEL_65B,  ((size_t) n_ctx /  6ull + 256ull) * MB }, // guess
-        { MODEL_70B,  ((size_t) n_ctx /  7ull + 164ull) * MB },
-    };
-    return k_sizes;
-}
-
-static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,  128ull * MB },
-        { MODEL_7B,  160ull * MB },
-        { MODEL_13B, 192ull * MB },
-        { MODEL_30B, 256ull * MB },
-        { MODEL_65B, 384ull * MB }, // guess
-        { MODEL_70B, 304ull * MB },
-    };
-    return k_sizes;
-}
-
-// used to store the compute graph tensors + non-scratch data
-static const std::map<e_model, size_t> & MEM_REQ_EVAL()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   8ull * MB },
-        { MODEL_7B,  10ull * MB },
-        { MODEL_13B, 12ull * MB },
-        { MODEL_30B, 16ull * MB },
-        { MODEL_65B, 24ull * MB }, // guess
-        { MODEL_70B, 24ull * MB },
-    };
-    return k_sizes;
-}
-
-// amount of VRAM needed per batch size to hold temporary results
-// the values for 3b are not derived from testing but instead chosen conservatively
-static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_BASE()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,   512ull * kB },
-        { MODEL_7B,   512ull * kB },
-        { MODEL_13B,  640ull * kB },
-        { MODEL_30B,  768ull * kB },
-        { MODEL_65B, 1280ull * kB },
-        { MODEL_70B, 1280ull * kB },
-    };
-    return k_sizes;
-}
-
-// amount of VRAM needed per batch size and context to hold temporary results
-// the values for 3b are not derived from testing but instead chosen conservatively
-static const std::map<e_model, size_t> & VRAM_REQ_SCRATCH_PER_CONTEXT()
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,  128ull },
-        { MODEL_7B,  128ull },
-        { MODEL_13B, 160ull },
-        { MODEL_30B, 208ull },
-        { MODEL_65B, 256ull },
-        { MODEL_70B, 256ull },
-    };
-    return k_sizes;
-}
-
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab     = 32000;
@@ -857,11 +773,9 @@ struct llama_context {
             ggml_metal_free(ctx_metal);
         }
 #endif
-#ifdef LLAMA_USE_ALLOCATOR
         if (alloc) {
             ggml_allocr_free(alloc);
         }
-#endif
     }

     std::mt19937 rng;
@@ -901,17 +815,8 @@ struct llama_context {
     // memory buffers used to evaluate the model
     llama_buffer buf_compute;

-#ifdef LLAMA_USE_ALLOCATOR
     llama_buffer buf_alloc;
     ggml_allocr * alloc = NULL;
-#endif
-
-#ifdef LLAMA_USE_SCRATCH
-    llama_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
-
-    int    buf_last = 0;
-    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
-#endif

 #ifdef GGML_USE_METAL
     ggml_metal_context * ctx_metal = NULL;
@@ -920,37 +825,6 @@ struct llama_context {
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
 #endif
-
-    void use_buf(struct ggml_context * ctx, int i) { // NOLINT
-#if defined(LLAMA_USE_SCRATCH)
-        size_t last_size = 0;
-
-        if (i == -1) {
-            last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, });
-        } else {
-            auto & buf = buf_scratch[i];
-            last_size = ggml_set_scratch(ctx, { 0, buf.size, buf.data, });
-        }
-
-        if (buf_last >= 0) {
-            buf_max_size[buf_last] = std::max(buf_max_size[buf_last], last_size);
-        }
-
-        buf_last = i;
-#else
-        (void) i;
-        (void) ctx;
-#endif
-    }
-
-    size_t get_buf_max_mem(int i) { // NOLINT
-#if defined(LLAMA_USE_SCRATCH)
-        return buf_max_size[i];
-#else
-        (void) i;
-        return 0;
-#endif
-    }
 };

 //
@@ -1620,7 +1494,6 @@ static void llama_model_load_internal(

     // prepare memory for the weights
     size_t vram_weights = 0;
-    size_t vram_scratch = 0;
     {
         const uint32_t n_embd     = hparams.n_embd;
         const uint32_t n_embd_gqa = hparams.n_embd_gqa();
@@ -1701,13 +1574,6 @@
             ctx_size +
             mmapped_size - vram_weights; // weights in VRAM not in memory

-#ifndef LLAMA_USE_ALLOCATOR
-        mem_required +=
-            MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
-            MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL().at(model.type);
-#endif
-
         // this is the memory required by one llama_state
         const size_t mem_required_state =
             scale*hparams.kv_size();
@@ -1715,24 +1581,7 @@
         LLAMA_LOG_INFO("%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
                 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);

-        (void) vram_scratch;
         (void) n_batch;
-#ifdef GGML_USE_CUBLAS
-        if (low_vram) {
-            LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
-            ggml_cuda_set_scratch_size(0); // disable scratch
-        } else {
-            const size_t vram_scratch_base = VRAM_REQ_SCRATCH_BASE().at(model.type);
-            const size_t vram_scratch_per_context = VRAM_REQ_SCRATCH_PER_CONTEXT().at(model.type);
-            vram_scratch = n_batch * (vram_scratch_base + n_ctx * vram_scratch_per_context);
-            ggml_cuda_set_scratch_size(vram_scratch);
-            if (n_gpu_layers > 0) {
-                LLAMA_LOG_INFO("%s: allocating batch_size x (%zd kB + n_ctx x %zd B) = %zd MB VRAM for the scratch buffer\n",
-                        __func__, vram_scratch_base / kB, vram_scratch_per_context,
-                        (vram_scratch + MB - 1) / MB); // round up
-            }
-        }
-#endif // GGML_USE_CUBLAS

 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
         const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
@@ -1769,8 +1618,8 @@

         LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n",
                 __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
-        LLAMA_LOG_INFO("%s: total VRAM used: %zu MB\n",
-                __func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up
+        LLAMA_LOG_INFO("%s: VRAM used: %zu MB\n",
+                __func__, (vram_weights + vram_kv_cache + MB - 1) / MB); // round up
 #else
         (void) n_gpu_layers;
 #endif // defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
@@ -1875,9 +1724,7 @@ static struct ggml_cgraph * llama_build_graph(
         /*.no_alloc   =*/ false,
     };

-#ifdef LLAMA_USE_ALLOCATOR
     params.no_alloc = true;
-#endif

     struct ggml_context * ctx0 = ggml_init(params);

@@ -1889,14 +1736,10 @@
     if (tokens) {
         struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);

-#ifdef LLAMA_USE_ALLOCATOR
         ggml_allocr_alloc(lctx.alloc, inp_tokens);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
             memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
         }
-#else
-        memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
-#endif
         ggml_set_name(inp_tokens, "inp_tokens");

         inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
@@ -1907,14 +1750,10 @@

         inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);

-#ifdef LLAMA_USE_ALLOCATOR
         ggml_allocr_alloc(lctx.alloc, inpL);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
             memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
         }
-#else
-        memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
-#endif
     }

     const int i_gpu_start = n_layer - n_gpu_layers;
@@ -1931,25 +1770,21 @@

 #ifdef GGML_USE_CUBLAS
     if (n_gpu_layers > n_layer) {
-        offload_func_nr = ggml_cuda_assign_buffers;
+        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
     }
     if (n_gpu_layers > n_layer + 1) {
-        offload_func_v  = ggml_cuda_assign_buffers;
+        offload_func_v  = ggml_cuda_assign_buffers_no_alloc;
     }
     if (n_gpu_layers > n_layer + 2) {
-        offload_func_kq = ggml_cuda_assign_buffers;
+        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
     }
 #endif // GGML_USE_CUBLAS

     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_alloc(lctx.alloc, KQ_scale);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
         ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
     }
-#else
-    ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
-#endif
     ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");

     for (int il = 0; il < n_layer; ++il) {
@@ -1959,14 +1794,12 @@

 #ifdef GGML_USE_CUBLAS
         if (il >= i_gpu_start) {
-            offload_func = ggml_cuda_assign_buffers;
+            offload_func = ggml_cuda_assign_buffers_no_alloc;
         }
 #endif // GGML_USE_CUBLAS

         struct ggml_tensor * inpSA = inpL;

-        lctx.use_buf(ctx0, 0);
-
         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
@@ -2104,8 +1937,6 @@
             ggml_set_name(cur, "result_wo");
         }

-        lctx.use_buf(ctx0, 1);
-
         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
         offload_func(inpFF);
         ggml_set_name(inpFF, "inpFF");
@@ -2160,8 +1991,6 @@
         inpL = cur;
     }

-    lctx.use_buf(ctx0, 0);
-
     // norm
     {
         cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
@@ -2178,8 +2007,6 @@
     cur = ggml_mul_mat(ctx0, model.output, cur);
     ggml_set_name(cur, "result_output");

-    lctx.use_buf(ctx0, -1);
-
     // logits -> probs
     //cur = ggml_soft_max_inplace(ctx0, cur);

@@ -2189,15 +2016,6 @@
         mem_per_token = ggml_used_mem(ctx0)/N;
     }

-#if 0
-    LLAMA_LOG_INFO("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
-            ggml_used_mem(ctx0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(1)/1024.0/1024.0,
-            lctx.work_buffer.size()/1024.0/1024.0,
-            n_past, N);
-#endif
-
     ggml_free(ctx0);

     return gf;
@@ -2248,14 +2066,26 @@ static bool llama_eval_internal(
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_vocab     = hparams.n_vocab;

-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_reset(lctx.alloc);
-#endif

     ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);

-#ifdef LLAMA_USE_ALLOCATOR
     ggml_allocr_alloc_graph(lctx.alloc, gf);
+
+#ifdef GGML_USE_CUBLAS
+    for (int i = 0; i < gf->n_leafs; i++) {
+        ggml_tensor * node = gf->leafs[i];
+        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+            ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+        }
+    }
+
+    for (int i = 0; i < gf->n_nodes; i++) {
+        ggml_tensor * node = gf->nodes[i];
+        if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) {
+            ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data);
+        }
+    }
 #endif

     // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -4319,7 +4149,6 @@ struct llama_context * llama_new_context_with_model(
             ctx->embedding.resize(hparams.n_embd);
         }

-#ifdef LLAMA_USE_ALLOCATOR
         {
             static const size_t tensor_alignment = 32;
             // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
@@ -4350,13 +4179,6 @@

             LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);

-            // debug - for comparison with scratch buffer
-            //size_t prev_req =
-            //    MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) +
-            //    MEM_REQ_SCRATCH1().at(ctx->model.type) +
-            //    MEM_REQ_EVAL().at(ctx->model.type);
-            //LLAMA_LOG_INFO("%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0);
-
             // recreate allocator with exact memory requirements
             ggml_allocr_free(ctx->alloc);

@@ -4366,16 +4188,17 @@
             if (ctx->ctx_metal) {
                 ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal));
             }
+#endif
+#ifdef GGML_USE_CUBLAS
+            if (params.low_vram) {
+                LLAMA_LOG_INFO("%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
+                ggml_cuda_set_scratch_size(0); // disable scratch
+            } else {
+                ggml_cuda_set_scratch_size(alloc_size);
+                LLAMA_LOG_INFO("%s: VRAM scratch buffer: %.2f MB\n", __func__, alloc_size / 1024.0 / 1024.0);
+            }
 #endif
         }
-#else
-        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
-#endif
-
-#ifdef LLAMA_USE_SCRATCH
-        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
-        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
-#endif
     }

 #ifdef GGML_USE_METAL