Mirror of https://github.com/ggml-org/llama.cpp.git
tests : fix overflow and memory leaks in test-model-random
* tests : fix integer types in test-model-random
src/llama-model.cpp

@@ -8847,9 +8847,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
 };
 
 struct llm_build_mamba : public llm_graph_context {
-    const llama_model & model;
-
-    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
+    llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         ggml_tensor * cur;
         ggml_tensor * inpL;
 
@@ -8865,7 +8863,7 @@ struct llm_build_mamba : public llm_graph_context {
                     LLM_NORM_RMS, il);
             cb(cur, "attn_norm", il);
 
-            cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
+            cur = build_mamba_layer(gf, cur, state_copy, model, ubatch, il);
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
@@ -8906,6 +8904,7 @@ struct llm_build_mamba : public llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * state_copy,
+            const llama_model & model,
             const llama_ubatch & ubatch,
             int il) const {
         const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
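The three hunks above remove the stored const llama_model & model member from llm_build_mamba and instead pass the model to build_mamba_layer as an argument. A minimal standalone sketch of that pattern, with placeholder types (illustrative only, not code from the patch):

// Illustrative only: the builder no longer keeps a reference to the model;
// callers hand it to the helper that actually needs it.
struct model_t { /* weights, per-layer tensors, ... */ };

struct graph_builder {
    // before: const model_t & model;   // stored reference, bound in the constructor

    graph_builder(const model_t & model) {
        build_layer(model, /*il=*/0);   // after: thread the model through as an argument
    }

    void build_layer(const model_t & model, int il) const {
        (void) model; (void) il;        // would read the layer-il weights here
    }
};

int main() {
    model_t m;
    graph_builder b(m);
    (void) b;
    return 0;
}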
tests/test-model-random.cpp

@@ -227,7 +227,7 @@ struct gguf_value {
             for (size_t i = 0; i < arr_size; ++i) {
                 memcpy(data.data() + type_size * i, &(*value.array)[i].value, type_size);
             }
-            gguf_set_arr_data(ctx, k, arr_type, data.data(), data.size());
+            gguf_set_arr_data(ctx, k, arr_type, data.data(), data.size() / type_size);
         }
         // TODO: handle nested arrays
     }
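The gguf fix above changes the count argument from bytes to elements. A small standalone sketch of the distinction, assuming ggml's public gguf.h API (the key name is made up for the example):

#include <cstdint>
#include <cstring>
#include <vector>

#include "gguf.h"

int main() {
    gguf_context * ctx = gguf_init_empty();

    const std::vector<int32_t> values = { 1, 2, 3, 4 };

    // Serialize through a raw byte buffer, as the test helper does for generic arrays.
    const size_t type_size = sizeof(int32_t);
    std::vector<uint8_t> data(values.size() * type_size);
    std::memcpy(data.data(), values.data(), data.size());

    // data.size() is 16 (bytes) while the array holds 4 elements,
    // and gguf_set_arr_data() expects the element count.
    gguf_set_arr_data(ctx, "example.arr", GGUF_TYPE_INT32, data.data(), data.size() / type_size);

    gguf_free(ctx);
    return 0;
}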
@@ -317,7 +317,12 @@ struct model_variant {
             gguf_add_tensor(ctx_gguf, tensor);
         }
 
-        return gguf_write_to_file(ctx_gguf, fname, false);
+        bool status = gguf_write_to_file(ctx_gguf, fname, false);
+
+        ggml_free(ctx);
+        gguf_free(ctx_gguf);
+
+        return status;
     }
 
     static void insert_from_arch(std::vector<model_variant> & variants, llm_arch arch) {
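The hunk above frees both contexts before returning the write status, so the early return no longer leaks them. An alternative shape for the same cleanup (a sketch, not what the patch does) is to wrap the C-style contexts in unique_ptr with the ggml/gguf free functions as deleters, so every return path releases them:

#include <memory>

#include "ggml.h"
#include "gguf.h"

// Sketch: RAII wrappers; both contexts are released on every return path.
using ggml_ctx_ptr = std::unique_ptr<ggml_context, decltype(&ggml_free)>;
using gguf_ctx_ptr = std::unique_ptr<gguf_context, decltype(&gguf_free)>;

bool write_gguf(const char * fname, ggml_context * raw_ctx, gguf_context * raw_ctx_gguf) {
    ggml_ctx_ptr ctx(raw_ctx, ggml_free);
    gguf_ctx_ptr ctx_gguf(raw_ctx_gguf, gguf_free);

    // ... gguf_add_tensor(ctx_gguf.get(), ...) for each tensor ...

    // No manual ggml_free()/gguf_free() needed before returning.
    return gguf_write_to_file(ctx_gguf.get(), fname, false);
}

int main() {
    ggml_init_params ip = { /*mem_size   =*/ 16 * 1024 * 1024,
                            /*mem_buffer =*/ nullptr,
                            /*no_alloc   =*/ true };
    // "/tmp/example.gguf" is just an illustrative output path.
    return write_gguf("/tmp/example.gguf", ggml_init(ip), gguf_init_empty()) ? 0 : 1;
}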
@@ -762,9 +767,8 @@ int main(int argc, char ** argv) {
     std::mt19937 rng(42);
 
     // TODO: multiple sequences per token
-    const int64_t n_batch = 2048;
-    const int64_t n_seq_len = 1024;
-    std::uniform_int_distribution<int64_t> rand_seq_init_len(n_seq_len / 4, 3 * n_seq_len / 4);
+    const int32_t n_batch = 2048;
+    const int32_t n_seq_len = 1024;
 
     llama_batch batch = llama_batch_init(n_batch, 0, 1);
     // TODO: batch with embeddings
@@ -794,10 +798,10 @@ int main(int argc, char ** argv) {
         // const auto n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
         // const auto n_embd = llama_model_n_embd(model);
 
-        for (int64_t n_seq_max : { 1, 2, 13 } ) {
+        for (int32_t n_seq_max : { 1, 2, 13 } ) {
 
             // TODO(later): context shift testing
-            for (int64_t n_ctx : { n_seq_len * n_seq_max }) {
+            for (int32_t n_ctx : { n_seq_len * n_seq_max }) {
 
                 std::vector<reference_logits> ref_outputs;
 
@@ -824,7 +828,7 @@ int main(int argc, char ** argv) {
 
                 for (bool shuffle : { false, true }) {
 
-                    for (int64_t n_ubatch : { 1, 2, 512 } ) {
+                    for (int32_t n_ubatch : { 1, 2, 512 } ) {
 
                         std::vector<bool> valid(n_seq_max, true);
 
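The int64_t to int32_t changes above keep the test's sizes in the same width as the llama.cpp parameters they feed, and as the %i specifiers used when they are printed later. A sketch of that assumption using the current llama.h context and batch API (illustrative, not code from the test):

#include <cstdint>

#include "llama.h"

// Sketch only: the values flow into 32-bit API fields, so int32_t avoids
// implicit narrowing at these call sites.
llama_context * make_test_context(llama_model * model,
        int32_t n_ctx, int32_t n_batch, int32_t n_ubatch, int32_t n_seq_max) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx     = n_ctx;      // uint32_t field
    cparams.n_batch   = n_batch;    // uint32_t field
    cparams.n_ubatch  = n_ubatch;   // uint32_t field
    cparams.n_seq_max = n_seq_max;  // uint32_t field
    return llama_init_from_model(model, cparams);
}

int main() {
    const int32_t n_batch = 2048;
    llama_batch batch = llama_batch_init(n_batch, 0, 1);  // llama_batch_init() takes int32_t
    llama_batch_free(batch);
    return 0;
}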
@@ -852,7 +856,7 @@ int main(int argc, char ** argv) {
                         if (batch.n_tokens < n_batch) {
                             const int64_t seq_len =
                                 std::min(n_batch - batch.n_tokens,
-                                         (int64_t) ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);
+                                         ref_outputs[seq_id].prompt_len - seq_id_n_past[seq_id]);
 
                             ref_outputs[seq_id].add_to_batch(batch, seq_id_n_past[seq_id], seq_len, seq_id);
                             seq_ids_in_batch.insert(seq_id);
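Dropping the (int64_t) cast follows from the type change: std::min deduces a single type from both arguments, so once the surrounding quantities are 32-bit the widening cast is no longer needed. An illustrative sketch, not code from the test:

#include <algorithm>
#include <cstdint>

int main() {
    const int32_t n_batch   = 2048;
    const int32_t n_tokens  = 100;
    const int32_t remaining = 300;

    // Both arguments deduce to int, so no widening cast is required:
    const int32_t seq_len = std::min(n_batch - n_tokens, remaining);

    // std::min(n_batch - n_tokens, (int64_t) remaining);  // would not compile:
    //                                                     // conflicting deduction (int vs int64_t)
    return seq_len > 0 ? 0 : 1;
}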
@@ -891,7 +895,7 @@ int main(int argc, char ** argv) {
                         }
 
                         fprintf(stdout,
-                                "Comparing output for '%s', with shuffle=%i, n_seq_max=%li, n_ctx=%li, n_ubatch=%li: ",
+                                "Comparing output for '%s', with shuffle=%i, n_seq_max=%i, n_ctx=%i, n_ubatch=%i: ",
                                 variant.name.c_str(), shuffle, n_seq_max, n_ctx, n_ubatch);
                         if (std::all_of(valid.begin(), valid.end(), [](bool v) { return v; })) {
                             fprintf(stdout, "\033[1;32mOK\033[0m\n");
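The last hunk matches the fprintf conversion specifiers to the now 32-bit arguments; passing an integer of the wrong width through a printf-style format is undefined behavior. A standalone sketch of the rule (PRId64 shown only as the portable way to print 64-bit values):

#include <cinttypes>
#include <cstdio>

int main() {
    const int32_t n_seq_max = 13;
    const int64_t n_ctx64   = 13312;

    std::printf("n_seq_max=%i\n", n_seq_max);       // OK: int32_t promotes to int, matching %i
    std::printf("n_ctx=%" PRId64 "\n", n_ctx64);    // OK: 64-bit value, 64-bit specifier
    // std::printf("n_ctx=%i\n", n_ctx64);          // mismatched width: undefined behavior
    return 0;
}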