mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-08 10:07:01 +00:00
memory : move the recurrent state into the memory context
This commit is contained in:
@@ -1092,12 +1092,15 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
|
||||
// Status-only constructor: records the given status and attaches no memory;
// all other members keep their defaults. Used when only a result status
// (e.g. a failure) needs to be conveyed to the caller.
llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
|
||||
|
||||
// Full-view constructor: the context spans the entire recurrent state of
// `mem`. The view parameters are snapshotted into the context at
// construction time: the window covers all cells (n_rs == size), starts at
// head 0, and the zeroed-state index rs_z is 0.
llama_memory_recurrent_context::llama_memory_recurrent_context(
        llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem),
    n_rs(mem->size), head(0), rs_z(0), size(mem->size) {
}
|
||||
|
||||
// Batch-processing constructor: takes ownership of the scheduled ubatches
// and snapshots the memory's current view (n, head, rs_z, size) into the
// context, so later getters do not need to re-read `mem`.
llama_memory_recurrent_context::llama_memory_recurrent_context(
        llama_memory_recurrent * mem,
        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)),
    n_rs(mem->n), head(mem->head), rs_z(mem->rs_z), size(mem->size) {
}
|
||||
|
||||
// Defaulted destructor: the context does not release `mem` (it holds a
// non-owning pointer); member containers clean up themselves.
llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
|
||||
|
||||
@@ -1138,19 +1141,19 @@ const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
|
||||
}
|
||||
|
||||
uint32_t llama_memory_recurrent_context::get_n_rs() const {
|
||||
return is_full ? mem->size : mem->n;
|
||||
return n_rs;
|
||||
}
|
||||
|
||||
uint32_t llama_memory_recurrent_context::get_head() const {
|
||||
return is_full ? 0 : mem->head;
|
||||
return head;
|
||||
}
|
||||
|
||||
int32_t llama_memory_recurrent_context::get_rs_z() const {
|
||||
return is_full ? 0 : mem->rs_z;
|
||||
return rs_z;
|
||||
}
|
||||
|
||||
uint32_t llama_memory_recurrent_context::get_size() const {
|
||||
return mem->size;
|
||||
return size;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
|
||||
@@ -1162,5 +1165,5 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
|
||||
}
|
||||
|
||||
int32_t llama_memory_recurrent_context::s_copy(int i) const {
|
||||
return mem->cells[i + mem->head].src0;
|
||||
return mem->cells[i + head].src0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user