mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	rework, targeting llama-server
This commit is contained in:
		@@ -314,6 +314,8 @@ struct llama_batch * llama_batch_get_one(
 | 
			
		||||
                 int32_t   n_tokens) {
 | 
			
		||||
    return new llama_batch{
 | 
			
		||||
        /*n_tokens       =*/ n_tokens,
 | 
			
		||||
        /*max_tokens     =*/ n_tokens,
 | 
			
		||||
        /*is_view        =*/ false,
 | 
			
		||||
        /*tokens         =*/ tokens,
 | 
			
		||||
        /*embd           =*/ nullptr,
 | 
			
		||||
        /*pos            =*/ nullptr,
 | 
			
		||||
@@ -326,6 +328,8 @@ struct llama_batch * llama_batch_get_one(
 | 
			
		||||
static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
 | 
			
		||||
    llama_batch * batch = new llama_batch{
 | 
			
		||||
        /*n_tokens       =*/ 0,
 | 
			
		||||
        /*max_tokens     =*/ n_tokens_alloc,
 | 
			
		||||
        /*is_view        =*/ false,
 | 
			
		||||
        /*tokens         =*/ nullptr,
 | 
			
		||||
        /*embd           =*/ nullptr,
 | 
			
		||||
        /*pos            =*/ nullptr,
 | 
			
		||||
@@ -364,50 +368,46 @@ struct llama_batch * llama_batch_init_from_embd(
 | 
			
		||||
            int32_t   seq_id) {
 | 
			
		||||
    struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1);
 | 
			
		||||
    memcpy(batch->embd, embd, n_embd * sizeof(float));
 | 
			
		||||
    for (int32_t i = 0; i < n_embd; i++) {
 | 
			
		||||
    for (size_t i = 0; i < n_embd; i++) {
 | 
			
		||||
        batch->pos     [i] = pos0 + i;
 | 
			
		||||
        batch->n_seq_id[i] = 1;
 | 
			
		||||
        batch->seq_id  [i][0] = seq_id;
 | 
			
		||||
    }
 | 
			
		||||
    return batch;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int32_t llama_batch_add_text(
 | 
			
		||||
int32_t llama_batch_get_n_tokens(const struct llama_batch * batch) {
 | 
			
		||||
    return batch->n_tokens;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int32_t llama_batch_add_text_token(
 | 
			
		||||
        struct llama_batch * batch,
 | 
			
		||||
               llama_token * tokens,
 | 
			
		||||
                   size_t    n_tokens,
 | 
			
		||||
                   int32_t   pos0,
 | 
			
		||||
                   int32_t * seq_ids,
 | 
			
		||||
                   size_t    n_seq_ids) {
 | 
			
		||||
    if (batch->n_tokens + n_tokens > batch->n_tokens) {
 | 
			
		||||
        return -1;
 | 
			
		||||
               llama_token   token,
 | 
			
		||||
                 llama_pos   pos,
 | 
			
		||||
        const llama_seq_id * seq_ids,
 | 
			
		||||
                    size_t   n_seq_ids,
 | 
			
		||||
                     float   logits) {
 | 
			
		||||
    if (batch->n_tokens + 1 > batch->max_tokens) {
 | 
			
		||||
        return -1; // llama_batch size exceeded
 | 
			
		||||
    }
 | 
			
		||||
    if (batch->embd) {
 | 
			
		||||
        return -2;
 | 
			
		||||
        return -2; // embd is already set, cannot add text tokens
 | 
			
		||||
    }
 | 
			
		||||
    for (int32_t i = 0; i < n_tokens; i++) {
 | 
			
		||||
        batch->token   [batch->n_tokens + i] = tokens[i];
 | 
			
		||||
        batch->pos     [batch->n_tokens + i] = pos0 + i;
 | 
			
		||||
        batch->n_seq_id[batch->n_tokens + i] = n_seq_ids;
 | 
			
		||||
        for (int32_t j = 0; j < n_seq_ids; j++) {
 | 
			
		||||
            batch->seq_id[batch->n_tokens + i][j] = seq_ids[j];
 | 
			
		||||
        }
 | 
			
		||||
    batch->token   [batch->n_tokens] = token;
 | 
			
		||||
    batch->pos     [batch->n_tokens] = pos;
 | 
			
		||||
    batch->n_seq_id[batch->n_tokens] = n_seq_ids;
 | 
			
		||||
    for (size_t j = 0; j < n_seq_ids; j++) {
 | 
			
		||||
        batch->seq_id[batch->n_tokens][j] = seq_ids[j];
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int32_t llama_batch_add_text(
 | 
			
		||||
        struct llama_batch * batch,
 | 
			
		||||
               llama_token * tokens,
 | 
			
		||||
                   size_t    n_tokens,
 | 
			
		||||
                   int32_t   pos0,
 | 
			
		||||
                   int32_t   seq_id) {
 | 
			
		||||
    std::array<int32_t, 1> seq_ids = { seq_id };
 | 
			
		||||
    return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size());
 | 
			
		||||
    batch->logits  [batch->n_tokens] = logits;
 | 
			
		||||
    batch->n_tokens++;
 | 
			
		||||
    return 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int32_t llama_batch_set_logits(
 | 
			
		||||
        struct llama_batch * batch,
 | 
			
		||||
                   int32_t   pos,
 | 
			
		||||
                   int32_t   seq_id) {
 | 
			
		||||
                 llama_pos   pos,
 | 
			
		||||
              llama_seq_id   seq_id) {
 | 
			
		||||
    for (int32_t i = 0; i < batch->n_tokens; i++) {
 | 
			
		||||
        // find the token having seq_id
 | 
			
		||||
        for (int32_t j = 0; j < batch->n_seq_id[i]; j++) {
 | 
			
		||||
@@ -415,28 +415,74 @@ int32_t llama_batch_set_logits(
 | 
			
		||||
                // found the sequence
 | 
			
		||||
                if (pos == -1 || pos == batch->pos[i]) {
 | 
			
		||||
                    batch->logits[i] = true;
 | 
			
		||||
                    break;
 | 
			
		||||
                    return 0;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    return -1; // not found
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
int32_t llama_batch_set_logits_last(struct llama_batch * batch) {
 | 
			
		||||
    if (batch->n_tokens == 0) {
 | 
			
		||||
        return -1;
 | 
			
		||||
    }
 | 
			
		||||
    batch->logits[batch->n_tokens - 1] = true;
 | 
			
		||||
    return 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void llama_batch_clear(struct llama_batch * batch) {
 | 
			
		||||
    batch->n_tokens = 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void llama_batch_free(struct llama_batch * batch) {
 | 
			
		||||
    if (batch->token)    free(batch->token);
 | 
			
		||||
    if (batch->embd)     free(batch->embd);
 | 
			
		||||
    if (batch->pos)      free(batch->pos);
 | 
			
		||||
    if (batch->n_seq_id) free(batch->n_seq_id);
 | 
			
		||||
    if (batch->seq_id) {
 | 
			
		||||
        for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
 | 
			
		||||
            free(batch->seq_id[i]);
 | 
			
		||||
        }
 | 
			
		||||
        free(batch->seq_id);
 | 
			
		||||
struct llama_batch * llama_batch_get_view(
 | 
			
		||||
        struct llama_batch * batch,
 | 
			
		||||
                   int32_t   offset,
 | 
			
		||||
                   int32_t   n_tokens) {
 | 
			
		||||
    if (batch->embd) {
 | 
			
		||||
        return nullptr; // not yet supported
 | 
			
		||||
    }
 | 
			
		||||
    llama_batch * batch_view = new llama_batch{
 | 
			
		||||
        /*n_tokens       =*/ n_tokens,
 | 
			
		||||
        /*max_tokens     =*/ n_tokens,
 | 
			
		||||
        /*is_view        =*/ true,
 | 
			
		||||
        /*tokens         =*/ batch->token    + offset,
 | 
			
		||||
        /*embd           =*/ nullptr,
 | 
			
		||||
        /*pos            =*/ batch->pos      + offset,
 | 
			
		||||
        /*n_seq_id       =*/ batch->n_seq_id + offset,
 | 
			
		||||
        /*seq_id         =*/ batch->seq_id   + offset,
 | 
			
		||||
        /*logits         =*/ batch->logits   + offset,
 | 
			
		||||
    };
 | 
			
		||||
    return batch_view;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct llama_batch_token_info llama_batch_get_token_info(
 | 
			
		||||
        struct llama_batch * batch,
 | 
			
		||||
                   int32_t   i) {
 | 
			
		||||
    GGML_ASSERT(i >= 0 && i < batch->n_tokens);
 | 
			
		||||
    return llama_batch_token_info{
 | 
			
		||||
        /*token    =*/ batch->token   [i],
 | 
			
		||||
        /*pos      =*/ batch->pos     [i],
 | 
			
		||||
        /*n_seq_id =*/ batch->n_seq_id[i],
 | 
			
		||||
        /*seq_id   =*/ batch->seq_id  [i],
 | 
			
		||||
        /*logits   =*/ batch->logits  [i],
 | 
			
		||||
    };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void llama_batch_free(struct llama_batch * batch) {
 | 
			
		||||
    // do not free the members if it's a view
 | 
			
		||||
    if (!batch->is_view) {
 | 
			
		||||
        if (batch->token)    free(batch->token);
 | 
			
		||||
        if (batch->embd)     free(batch->embd);
 | 
			
		||||
        if (batch->pos)      free(batch->pos);
 | 
			
		||||
        if (batch->n_seq_id) free(batch->n_seq_id);
 | 
			
		||||
        if (batch->seq_id) {
 | 
			
		||||
            for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
 | 
			
		||||
                free(batch->seq_id[i]);
 | 
			
		||||
            }
 | 
			
		||||
            free(batch->seq_id);
 | 
			
		||||
        }
 | 
			
		||||
        if (batch->logits)   free(batch->logits);
 | 
			
		||||
    }
 | 
			
		||||
    if (batch->logits)   free(batch->logits);
 | 
			
		||||
    delete batch;
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -20,6 +20,8 @@
 | 
			
		||||
//
 | 
			
		||||
struct llama_batch {
 | 
			
		||||
    int32_t n_tokens;
 | 
			
		||||
    int32_t max_tokens;
 | 
			
		||||
    bool is_view;
 | 
			
		||||
 | 
			
		||||
    llama_token  *  token;
 | 
			
		||||
    float        *  embd;
 | 
			
		||||
 
 | 
			
		||||
@@ -9978,8 +9978,8 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) {
 | 
			
		||||
 | 
			
		||||
int32_t llama_encode(
 | 
			
		||||
        struct llama_context * ctx,
 | 
			
		||||
          struct llama_batch   batch) {
 | 
			
		||||
    const int ret = llama_encode_impl(*ctx, batch);
 | 
			
		||||
          struct llama_batch * batch) {
 | 
			
		||||
    const int ret = llama_encode_impl(*ctx, *batch);
 | 
			
		||||
    if (ret != 0) {
 | 
			
		||||
        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
 | 
			
		||||
    }
 | 
			
		||||
@@ -9989,8 +9989,8 @@ int32_t llama_encode(
 | 
			
		||||
 | 
			
		||||
int32_t llama_decode(
 | 
			
		||||
        struct llama_context * ctx,
 | 
			
		||||
          struct llama_batch   batch) {
 | 
			
		||||
    const int ret = llama_decode_impl(*ctx, batch);
 | 
			
		||||
          struct llama_batch * batch) {
 | 
			
		||||
    const int ret = llama_decode_impl(*ctx, *batch);
 | 
			
		||||
    if (ret != 0) {
 | 
			
		||||
        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
 | 
			
		||||
    }
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user