mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-07 09:57:00 +00:00
allocators wip
renamed ggml_backend functions changed ggml_buffer and ggml_backend to always be used as pointers rename ggml_tensor::params -> op_params
This commit is contained in:
183
llama.cpp
183
llama.cpp
@@ -172,7 +172,7 @@ struct llama_kv_cache {
|
||||
|
||||
struct ggml_context * ctx = NULL;
|
||||
|
||||
ggml_buffer buf;
|
||||
ggml_buffer * buf;
|
||||
|
||||
int n; // number of tokens currently in the cache
|
||||
|
||||
@@ -225,29 +225,29 @@ struct llama_model {
|
||||
llama_vocab vocab;
|
||||
|
||||
// backends
|
||||
ggml_backend backend_cpu;
|
||||
ggml_buffer buf_cpu;
|
||||
ggml_backend * backend_cpu = NULL;
|
||||
ggml_buffer * buf_cpu = NULL;
|
||||
ggml_context * ctx_cpu = NULL;
|
||||
#ifdef GGML_USE_CUDA
|
||||
ggml_backend backend_cuda;
|
||||
ggml_buffer buf_cuda;
|
||||
ggml_backend * backend_cuda = NULL;
|
||||
ggml_buffer * buf_cuda = NULL;
|
||||
ggml_context * ctx_cuda = NULL;
|
||||
#endif
|
||||
|
||||
// backend assigned to each layer
|
||||
ggml_backend * backend_input = NULL;
|
||||
ggml_backend * backend_output = NULL;
|
||||
ggml_backend * backend_inp = NULL;
|
||||
ggml_backend * backend_out = NULL;
|
||||
std::vector<ggml_backend *> backend_layers;
|
||||
|
||||
~llama_model() {
|
||||
if (ctx_cpu) {
|
||||
ggml_free(ctx_cpu);
|
||||
ggml_backend_free_buffer(&buf_cpu);
|
||||
ggml_buffer_free(buf_cpu);
|
||||
}
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (ctx_cuda) {
|
||||
ggml_free(ctx_cuda);
|
||||
ggml_backend_free_buffer(&buf_cuda);
|
||||
ggml_buffer_free(buf_cuda);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -286,9 +286,9 @@ struct llama_context {
|
||||
std::vector<float> embedding;
|
||||
|
||||
// memory buffers used to evaluate the model
|
||||
ggml_buffer buf_compute_cpu = {};
|
||||
ggml_buffer * buf_compute_cpu;
|
||||
#ifdef GGML_USE_CUDA
|
||||
ggml_buffer buf_compute_cuda = {};
|
||||
ggml_buffer * buf_compute_cuda;
|
||||
#endif
|
||||
|
||||
// input tensors
|
||||
@@ -300,8 +300,19 @@ struct llama_context {
|
||||
struct ggml_tensor * graph_embeddings_out = nullptr;
|
||||
|
||||
// buffers to store the inputs and outputs of the graphs
|
||||
ggml_buffer buf_input = {};
|
||||
ggml_buffer buf_output = {};
|
||||
ggml_buffer * buf_input;
|
||||
ggml_buffer * buf_output;
|
||||
|
||||
/*
|
||||
~llama_context() {
|
||||
if (model_owner) {
|
||||
delete &model;
|
||||
}
|
||||
if (buf_compute_cpu) {
|
||||
ggml_buffer_free(buf_compute_cpu);
|
||||
}
|
||||
}
|
||||
*/
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -601,9 +612,6 @@ struct llama_model_loader {
|
||||
void load_all_data(llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) {
|
||||
size_t data_size = 0;
|
||||
size_t lock_size = 0;
|
||||
for (const llama_load_tensor & lt : tensors_map.tensors) {
|
||||
data_size += lt.size;
|
||||
}
|
||||
|
||||
if (use_mmap) {
|
||||
mapping.reset(new llama_mmap(&file_loader->file, false, ggml_is_numa()));
|
||||
@@ -613,14 +621,28 @@ struct llama_model_loader {
|
||||
}
|
||||
|
||||
size_t done_size = 0;
|
||||
std::vector<uint8_t> tmp_buf;
|
||||
std::vector<uint8_t> load_buf;
|
||||
size_t load_buf_size = 0;
|
||||
for (llama_load_tensor & lt : tensors_map.tensors) {
|
||||
bool is_cpu = lt.ggml_tensor->backend == model->backend_cpu;
|
||||
if (!use_mmap && !is_cpu) {
|
||||
load_buf_size = std::max(load_buf_size, lt.size);
|
||||
}
|
||||
data_size += lt.size;
|
||||
}
|
||||
if (load_buf_size > 0) {
|
||||
load_buf.resize(load_buf_size);
|
||||
// may improve CUDA loading speed without mmap
|
||||
//ggml_cuda_host_register(load_buf.data(), load_buf.size());
|
||||
}
|
||||
|
||||
for (llama_load_tensor & lt : tensors_map.tensors) {
|
||||
if (progress_callback) {
|
||||
progress_callback((float) done_size / data_size, progress_callback_user_data);
|
||||
}
|
||||
LLAMA_ASSERT(lt.ggml_tensor); // unused tensors should have been caught by load_data already
|
||||
|
||||
bool is_cpu = lt.ggml_tensor->backend == &model->backend_cpu;
|
||||
bool is_cpu = lt.ggml_tensor->backend == model->backend_cpu;
|
||||
|
||||
// select buffer to load data into
|
||||
if (!use_mmap) {
|
||||
@@ -628,8 +650,7 @@ struct llama_model_loader {
|
||||
lt.data = (uint8_t *) lt.ggml_tensor->data;
|
||||
} else {
|
||||
// read to temporary buffer
|
||||
tmp_buf.resize(lt.size);
|
||||
lt.data = (uint8_t *) tmp_buf.data();
|
||||
lt.data = (uint8_t *) load_buf.data();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -645,7 +666,7 @@ struct llama_model_loader {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ggml_backend_set_tensor(lt.ggml_tensor, lt.data, 0, lt.size);
|
||||
ggml_backend_tensor_set(lt.ggml_tensor, lt.data, 0, lt.size);
|
||||
if (use_mmap) {
|
||||
// hint the OS that we don't need the data anymore
|
||||
// TODO: this may be a bad idea with devices that use the system memory (Metal?)
|
||||
@@ -655,6 +676,9 @@ struct llama_model_loader {
|
||||
|
||||
done_size += lt.size;
|
||||
}
|
||||
//if (load_buf_size > 0) {
|
||||
// ggml_cuda_host_unregister(load_buf.data());
|
||||
//}
|
||||
}
|
||||
|
||||
void load_data_for(llama_load_tensor & lt) {
|
||||
@@ -701,11 +725,11 @@ static bool kv_cache_init(
|
||||
|
||||
size_t size = 2u*n_elements*ggml_type_size(wtype) + 2u*MB;
|
||||
|
||||
cache.buf = ggml_backend_alloc_buffer(backend, size, 2);
|
||||
cache.buf = ggml_buffer_alloc(backend, size, 2);
|
||||
cache.n = 0;
|
||||
|
||||
struct ggml_init_params params = ggml_init_params_default();
|
||||
params.buffer = &cache.buf;
|
||||
params.buffer = cache.buf;
|
||||
|
||||
cache.ctx = ggml_init(params);
|
||||
|
||||
@@ -771,7 +795,7 @@ void llama_backend_init(bool numa) {
|
||||
// needed to initialize f16 tables
|
||||
{
|
||||
struct ggml_init_params params = ggml_init_params_default();
|
||||
params.buffer = {0};
|
||||
params.buffer = NULL;
|
||||
struct ggml_context * ctx = ggml_init(params);
|
||||
ggml_free(ctx);
|
||||
}
|
||||
@@ -940,30 +964,30 @@ static void llama_model_load_internal(
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
|
||||
model.backend_cpu = ggml_backend_cpu_init();
|
||||
ggml_backend * backend_gpu = &model.backend_cpu; // hack until we have a proper backend selection
|
||||
ggml_backend * backend_gpu = model.backend_cpu; // hack until we have a proper backend selection
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (n_gpu_layers > 0) {
|
||||
model.backend_cuda = ggml_backend_cuda_init();
|
||||
backend_gpu = &model.backend_cuda;
|
||||
backend_gpu = model.backend_cuda;
|
||||
}
|
||||
#endif
|
||||
|
||||
// assign splits to the backends
|
||||
const int i_gpu_start = std::max(0, (int)n_layer - n_gpu_layers);
|
||||
model.backend_input = n_gpu_layers > (int)n_layer ? backend_gpu : &model.backend_cpu;
|
||||
model.backend_output = n_gpu_layers > 0 ? backend_gpu : &model.backend_cpu;
|
||||
model.backend_inp = n_gpu_layers > (int)n_layer ? backend_gpu : model.backend_cpu;
|
||||
model.backend_out = n_gpu_layers > 0 ? backend_gpu : model.backend_cpu;
|
||||
model.backend_layers.resize(n_layer);
|
||||
std::fill(model.backend_layers.begin(), model.backend_layers.begin() + i_gpu_start, &model.backend_cpu);
|
||||
std::fill(model.backend_layers.begin(), model.backend_layers.begin() + i_gpu_start, model.backend_cpu);
|
||||
std::fill(model.backend_layers.begin() + i_gpu_start, model.backend_layers.end(), backend_gpu);
|
||||
|
||||
// calculate the size of each context
|
||||
std::unordered_map<struct ggml_backend *, size_t> ctx_sizes;
|
||||
for (const llama_load_tensor & lt : ml->tensors_map.tensors) {
|
||||
if (lt.name == "tok_embeddings.weight") {
|
||||
ctx_sizes[model.backend_input] += lt.size;
|
||||
ctx_sizes[model.backend_inp] += lt.size;
|
||||
}
|
||||
else if (lt.name == "norm.weight" || lt.name == "output.weight") {
|
||||
ctx_sizes[model.backend_output] += lt.size;
|
||||
ctx_sizes[model.backend_out] += lt.size;
|
||||
}
|
||||
else {
|
||||
// parse layer number from name
|
||||
@@ -980,14 +1004,14 @@ static void llama_model_load_internal(
|
||||
// TODO: generalize support for mmap
|
||||
size_t mmap_size = 0;
|
||||
if (ml->use_mmap) {
|
||||
mmap_size = ctx_sizes[&model.backend_cpu];
|
||||
ctx_sizes[&model.backend_cpu] = 0;
|
||||
mmap_size = ctx_sizes[model.backend_cpu];
|
||||
ctx_sizes[model.backend_cpu] = 0;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: ggml ctx sizes:\n", __func__);
|
||||
for (const auto & it : ctx_sizes) {
|
||||
fprintf(stderr, "%8s = %7.2f MB", ggml_backend_name(it.first), it.second / 1024.0 / 1024.0);
|
||||
if (it.first == &model.backend_cpu && ml->use_mmap) {
|
||||
if (it.first == model.backend_cpu && ml->use_mmap) {
|
||||
fprintf(stderr, " + %7.2f MB (mmap)", mmap_size / 1024.0 / 1024.0);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
@@ -996,10 +1020,10 @@ static void llama_model_load_internal(
|
||||
// create the buffers and contexts
|
||||
{
|
||||
size_t cpu_num_tensors = ml->tensors_map.tensors.size();
|
||||
size_t ctx_size = ctx_sizes[&model.backend_cpu];
|
||||
model.buf_cpu = ggml_backend_alloc_buffer(&model.backend_cpu, ctx_size, cpu_num_tensors);
|
||||
size_t ctx_size = ctx_sizes[model.backend_cpu];
|
||||
model.buf_cpu = ggml_buffer_alloc(model.backend_cpu, ctx_size, cpu_num_tensors);
|
||||
struct ggml_init_params params = ggml_init_params_default();
|
||||
params.buffer = &model.buf_cpu;
|
||||
params.buffer = model.buf_cpu;
|
||||
params.no_alloc = ml->use_mmap;
|
||||
model.ctx_cpu = ggml_init(params);
|
||||
if (!model.ctx_cpu) {
|
||||
@@ -1011,10 +1035,10 @@ static void llama_model_load_internal(
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (n_gpu_layers > 0) {
|
||||
size_t gpu_num_tensors = ml->tensors_map.tensors.size();
|
||||
size_t ctx_size = ctx_sizes[&model.backend_cuda];
|
||||
model.buf_cuda = ggml_backend_alloc_buffer(&model.backend_cuda, ctx_size, gpu_num_tensors);
|
||||
size_t ctx_size = ctx_sizes[model.backend_cuda];
|
||||
model.buf_cuda = ggml_buffer_alloc(model.backend_cuda, ctx_size, gpu_num_tensors);
|
||||
struct ggml_init_params params = ggml_init_params_default();
|
||||
params.buffer = &model.buf_cuda;
|
||||
params.buffer = model.buf_cuda;
|
||||
model.ctx_cuda = ggml_init(params);
|
||||
if (!model.ctx_cuda) {
|
||||
throw std::runtime_error(format("ggml_init() failed for CUDA backend"));
|
||||
@@ -1025,9 +1049,9 @@ static void llama_model_load_internal(
|
||||
|
||||
// TODO: clean this
|
||||
ggml_context * ctx_input = model.ctx_cpu;
|
||||
if (model.backend_input == backend_gpu) ctx_input = ctx_gpu;
|
||||
if (model.backend_inp == backend_gpu) ctx_input = ctx_gpu;
|
||||
ggml_context * ctx_output = model.ctx_cpu;
|
||||
if (model.backend_output == backend_gpu) ctx_output = ctx_gpu;
|
||||
if (model.backend_out == backend_gpu) ctx_output = ctx_gpu;
|
||||
std::vector<ggml_context *> ctx_layers(n_layer, model.ctx_cpu);
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
if (model.backend_layers[i] == backend_gpu) {
|
||||
@@ -1181,18 +1205,18 @@ static ggml_graph_splits llama_build_graph(
|
||||
// initialize contexts for every backend
|
||||
|
||||
struct ggml_context * ctx_cpu = nullptr;
|
||||
if (lctx.buf_compute_cpu.mem_size > 0) {
|
||||
if (lctx.buf_compute_cpu != nullptr) {
|
||||
struct ggml_init_params params = ggml_init_params_default();
|
||||
params.buffer = &lctx.buf_compute_cpu;
|
||||
params.buffer = lctx.buf_compute_cpu;
|
||||
params.compute_type = compute_type;
|
||||
ctx_cpu = ggml_init(params);
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
struct ggml_context * ctx_cuda = nullptr;
|
||||
if (lctx.buf_compute_cuda.mem_size > 0) {
|
||||
if (lctx.buf_compute_cuda != nullptr) {
|
||||
struct ggml_init_params params = ggml_init_params_default();
|
||||
params.buffer = &lctx.buf_compute_cuda;
|
||||
params.buffer = lctx.buf_compute_cuda;
|
||||
params.compute_type = compute_type;
|
||||
ctx_cuda = ggml_init(params);
|
||||
}
|
||||
@@ -1204,26 +1228,30 @@ static ggml_graph_splits llama_build_graph(
|
||||
struct ggml_context * ctx_o = nullptr;
|
||||
struct ggml_context * ctx_kv = nullptr;
|
||||
|
||||
if (lctx.model.backend_input == &lctx.model.backend_cpu) ctx_i = ctx_cpu;
|
||||
if (lctx.model.backend_output == &lctx.model.backend_cpu) ctx_o = ctx_cpu;
|
||||
if (lctx.model.backend_inp == lctx.model.backend_cpu) ctx_i = ctx_cpu;
|
||||
if (lctx.model.backend_out == lctx.model.backend_cpu) ctx_o = ctx_cpu;
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (lctx.model.backend_input == &lctx.model.backend_cuda) ctx_i = ctx_cuda;
|
||||
if (lctx.model.backend_output == &lctx.model.backend_cuda) ctx_o = ctx_cuda;
|
||||
if (lctx.model.backend_inp == lctx.model.backend_cuda) ctx_i = ctx_cuda;
|
||||
if (lctx.model.backend_out == lctx.model.backend_cuda) ctx_o = ctx_cuda;
|
||||
#endif
|
||||
for (int il = 0; il < n_layer; il++) {
|
||||
if (lctx.model.backend_layers[il] == &lctx.model.backend_cpu) ctx_ls[il] = ctx_cpu;
|
||||
if (lctx.model.backend_layers[il] == lctx.model.backend_cpu) ctx_ls[il] = ctx_cpu;
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (lctx.model.backend_layers[il] == &lctx.model.backend_cuda) ctx_ls[il] = ctx_cuda;
|
||||
if (lctx.model.backend_layers[il] == lctx.model.backend_cuda) ctx_ls[il] = ctx_cuda;
|
||||
#endif
|
||||
}
|
||||
if (lctx.backend_kv == &lctx.model.backend_cpu) ctx_kv = ctx_cpu;
|
||||
if (lctx.backend_kv == lctx.model.backend_cpu) ctx_kv = ctx_cpu;
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (lctx.backend_kv == &lctx.model.backend_cuda) ctx_kv = ctx_cuda;
|
||||
if (lctx.backend_kv == lctx.model.backend_cuda) ctx_kv = ctx_cuda;
|
||||
#endif
|
||||
|
||||
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
// reuse the scale tensor for all layers since it requires a memory transfer
|
||||
struct ggml_tensor * KQ_scale = ggml_new_f32(ctx_kv, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");
|
||||
|
||||
if (embeddings_input) {
|
||||
// use embeddings as input
|
||||
struct ggml_tensor * embd_in = lctx.graph_embeddings_in;
|
||||
@@ -1236,10 +1264,6 @@ static ggml_graph_splits llama_build_graph(
|
||||
inpL = ggml_get_rows(ctx_i, model.tok_embeddings, token_in);
|
||||
}
|
||||
|
||||
// reuse the scale tensor for all layers since it requires a memory transfer
|
||||
struct ggml_tensor * KQ_scale = ggml_new_f32(ctx_kv, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");
|
||||
|
||||
struct ggml_tensor * cur = nullptr;
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_context * ctx_l = ctx_ls[il];
|
||||
@@ -1540,16 +1564,16 @@ static bool llama_eval_internal(
|
||||
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
||||
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
||||
n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;
|
||||
ggml_backend_cpu_set_n_threads(const_cast<ggml_backend*>(&model.backend_cpu), n_threads);
|
||||
ggml_backend_cpu_set_n_threads(const_cast<ggml_backend*>(model.backend_cpu), n_threads);
|
||||
|
||||
struct ggml_graph_splits splits = llama_build_graph(lctx, N, n_past, embd_input);
|
||||
|
||||
if (tokens != nullptr) {
|
||||
// copy the tokens to the input tensor
|
||||
ggml_backend_set_tensor_async(lctx.graph_tokens_in, tokens, 0, N*ggml_element_size(lctx.graph_tokens_in));
|
||||
ggml_backend_tensor_set_async(lctx.graph_tokens_in, tokens, 0, N*ggml_element_size(lctx.graph_tokens_in));
|
||||
} else {
|
||||
// copy the embeddings to the input tensor
|
||||
ggml_backend_set_tensor_async(lctx.graph_embeddings_in, embd, 0, N*n_embd*ggml_element_size(lctx.graph_embeddings_in));
|
||||
ggml_backend_tensor_set_async(lctx.graph_embeddings_in, embd, 0, N*n_embd*ggml_element_size(lctx.graph_embeddings_in));
|
||||
}
|
||||
|
||||
// run the computation
|
||||
@@ -1577,11 +1601,11 @@ static bool llama_eval_internal(
|
||||
|
||||
if (lctx.logits_all) {
|
||||
logits_out.resize(n_vocab * N);
|
||||
ggml_backend_get_tensor_async(lctx.graph_logits, logits_out.data(), 0, N*n_vocab*sizeof(float));
|
||||
ggml_backend_tensor_get_async(lctx.graph_logits, logits_out.data(), 0, N*n_vocab*sizeof(float));
|
||||
} else {
|
||||
// return result for just the last token
|
||||
logits_out.resize(n_vocab);
|
||||
ggml_backend_get_tensor_async(lctx.graph_logits, logits_out.data(), 0, n_vocab*sizeof(float));
|
||||
ggml_backend_tensor_get_async(lctx.graph_logits, logits_out.data(), 0, n_vocab*sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1589,13 +1613,13 @@ static bool llama_eval_internal(
|
||||
if (!lctx.embedding.empty()) {
|
||||
auto & embedding_out = lctx.embedding;
|
||||
embedding_out.resize(n_embd);
|
||||
ggml_backend_get_tensor_async(lctx.graph_embeddings_out, embedding_out.data(), 0, n_embd*sizeof(float));
|
||||
ggml_backend_tensor_get_async(lctx.graph_embeddings_out, embedding_out.data(), 0, n_embd*sizeof(float));
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
// wait for the async copy to finish
|
||||
if (lctx.model.n_gpu_layers > 0) {
|
||||
ggml_backend_synchronize(const_cast<ggml_backend*>(&lctx.model.backend_cuda));
|
||||
ggml_backend_synchronize(const_cast<ggml_backend*>(lctx.model.backend_cuda));
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -2063,7 +2087,7 @@ void llama_sample_classifier_free_guidance(
|
||||
struct llama_context * guidance_ctx,
|
||||
float scale,
|
||||
float smooth_factor) {
|
||||
int64_t t_start_sample_us = t_start_sample_us = ggml_time_us();
|
||||
int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
assert(ctx);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
@@ -2608,13 +2632,13 @@ struct llama_context * llama_new_context_with_model(
|
||||
|
||||
// TODO: choose backend depending on n_layers/low_vram
|
||||
#ifdef GGML_USE_CUDA
|
||||
if ((uint32_t)params.n_gpu_layers >= model->hparams.n_layer/2) {
|
||||
ctx->backend_kv = &model->backend_cuda;
|
||||
if ((uint32_t)params.n_gpu_layers >= model->hparams.n_layer/2 && !params.low_vram) {
|
||||
ctx->backend_kv = model->backend_cuda;
|
||||
} else {
|
||||
ctx->backend_kv = &model->backend_cpu;
|
||||
ctx->backend_kv = model->backend_cpu;
|
||||
}
|
||||
#else
|
||||
ctx->backend_kv = &model->backend_cpu;
|
||||
ctx->backend_kv = model->backend_cpu;
|
||||
#endif
|
||||
|
||||
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
@@ -2639,10 +2663,12 @@ struct llama_context * llama_new_context_with_model(
|
||||
}
|
||||
|
||||
// TODO: size the buffers more accurately - depends on improved memory management
|
||||
ctx->buf_compute_cpu = ggml_backend_alloc_buffer(&model->backend_cpu, MEM_REQ_EVAL().at(ctx->model.type), 2048);
|
||||
ctx->buf_compute_cpu = ggml_buffer_alloc(model->backend_cpu, MEM_REQ_EVAL().at(ctx->model.type), 2048);
|
||||
// TODO: pinned memory for faster host-device transfers
|
||||
//ggml_cuda_host_register(*(void**)ctx->buf_compute_cpu.backend_buffer, MEM_REQ_EVAL().at(ctx->model.type) + 128*2048);
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (params.n_gpu_layers > 0) {
|
||||
ctx->buf_compute_cuda = ggml_backend_alloc_buffer(&model->backend_cuda, MEM_REQ_EVAL().at(ctx->model.type), 2048);
|
||||
ctx->buf_compute_cuda = ggml_buffer_alloc(model->backend_cuda, MEM_REQ_EVAL().at(ctx->model.type), 2048);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -2653,10 +2679,10 @@ struct llama_context * llama_new_context_with_model(
|
||||
buf_input_size += hparams.n_ctx * ggml_type_size(GGML_TYPE_F32); // input tokens
|
||||
// TODO: input embeddings should be optional to save memory
|
||||
buf_input_size += hparams.n_embd * hparams.n_ctx * ggml_type_size(GGML_TYPE_F32); // input embeddings
|
||||
ctx->buf_input = ggml_backend_alloc_buffer(model->backend_input, buf_input_size, 2);
|
||||
ctx->buf_input = ggml_buffer_alloc(model->backend_inp, buf_input_size, 2);
|
||||
|
||||
struct ggml_init_params ggml_params = ggml_init_params_default();
|
||||
ggml_params.buffer = &ctx->buf_input;
|
||||
ggml_params.buffer = ctx->buf_input;
|
||||
ggml_context * ctx0 = ggml_init(ggml_params);
|
||||
|
||||
ctx->graph_tokens_in = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, hparams.n_ctx);
|
||||
@@ -2677,10 +2703,10 @@ struct llama_context * llama_new_context_with_model(
|
||||
if (params.embedding) {
|
||||
buf_output_size += hparams.n_embd * ggml_type_size(GGML_TYPE_F32);
|
||||
}
|
||||
ctx->buf_output = ggml_backend_alloc_buffer(model->backend_output, buf_output_size, 2);
|
||||
ctx->buf_output = ggml_buffer_alloc(model->backend_out, buf_output_size, 2);
|
||||
|
||||
struct ggml_init_params ggml_params = ggml_init_params_default();
|
||||
ggml_params.buffer = &ctx->buf_output;
|
||||
ggml_params.buffer = ctx->buf_output;
|
||||
ggml_context * ctx0 = ggml_init(ggml_params);
|
||||
|
||||
ctx->graph_logits = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_vocab, params.logits_all ? hparams.n_ctx : 1);
|
||||
@@ -2706,7 +2732,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: layer backends: ", __func__);
|
||||
fprintf(stderr, "input: %s, ", ggml_backend_name(ctx->model.backend_input));
|
||||
fprintf(stderr, "input: %s, ", ggml_backend_name(ctx->model.backend_inp));
|
||||
|
||||
int start = 0;
|
||||
struct ggml_backend * prev_backend = ctx->model.backend_layers[0];
|
||||
@@ -2721,7 +2747,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
prev_backend = ctx->model.backend_layers[i];
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "output: %s, ", ggml_backend_name(ctx->model.backend_output));
|
||||
fprintf(stderr, "output: %s, ", ggml_backend_name(ctx->model.backend_out));
|
||||
fprintf(stderr, "kv: %s\n", ggml_backend_name(ctx->backend_kv));
|
||||
|
||||
#ifdef GGML_USE_MPI
|
||||
@@ -2753,6 +2779,7 @@ struct llama_context * llama_init_from_file(
|
||||
}
|
||||
|
||||
void llama_free(struct llama_context * ctx) {
|
||||
// TODO: free buffers - move this to destructor like llama_model
|
||||
if (ctx->model_owner) {
|
||||
delete &ctx->model;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user