diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 6b101f4869..81663c4001 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -261,7 +261,7 @@ void llama_context_base::init() { LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); - const size_t max_nodes = this->max_nodes(); + const size_t max_nodes = this->graph_max_nodes(); LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); @@ -420,10 +420,6 @@ const llama_model & llama_context_base::get_model() const { return model; } -const llama_cparams & llama_context_base::get_cparams() const { - return cparams; -} - uint32_t llama_context_base::n_ctx() const { return cparams.n_ctx; } @@ -452,10 +448,6 @@ uint32_t llama_context_base::n_threads_batch() const { return cparams.n_threads_batch; } -int32_t llama_context_base::max_nodes() const { - return std::max(8192, 5*model.n_tensors()); -} - llama_kv_cache * llama_context_base::get_kv_self() { LLAMA_LOG_WARN("%s: llama_context_base does not have a KV cache\n", __func__); return nullptr; @@ -573,10 +565,6 @@ float * llama_context_base::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } -int64_t llama_context_base::n_pos_per_token() const { - return model.arch == LLM_ARCH_QWEN2VL ? 4 : 1; -} - void llama_context_base::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { @@ -1007,6 +995,10 @@ int llama_context_base::decode(llama_batch & inp_batch) { // input // +int64_t llama_context_base::n_pos_per_token() const { + return model.arch == LLM_ARCH_QWEN2VL ? 4 : 1; +} + void llama_context_base::input_set(const llama_ubatch & ubatch) { const llama_hparams & hparams = model.hparams; @@ -1391,6 +1383,10 @@ void llama_context_base::output_reorder() { // graph // +int32_t llama_context_base::graph_max_nodes() const { + return std::max(8192, 5*model.n_tensors()); +} + ggml_cgraph * llama_context_base::graph_init() { inp = {}; @@ -1402,7 +1398,7 @@ ggml_cgraph * llama_context_base::graph_init() { ctx_compute.reset(ggml_init(params)); - return ggml_new_graph_custom(ctx_compute.get(), max_nodes(), false); + return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); } llama_graph_result llama_context_base::graph_build( @@ -2034,7 +2030,7 @@ private: size_t llama_context_base::state_get_size() { llama_io_write_dummy io; try { - return state_get_data(io); + return state_write_data(io); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); return 0; @@ -2044,7 +2040,7 @@ size_t llama_context_base::state_get_size() { size_t llama_context_base::state_get_data(uint8_t * dst, size_t size) { llama_io_write_buffer io(dst, size); try { - return state_get_data(io); + return state_write_data(io); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); return 0; @@ -2054,7 +2050,7 @@ size_t llama_context_base::state_get_data(uint8_t * dst, size_t size) { size_t llama_context_base::state_set_data(const uint8_t * src, size_t size) { llama_io_read_buffer io(src, size); try { - return state_set_data(io); + return state_read_data(io); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); return 0; @@ -2064,7 +2060,7 @@ size_t llama_context_base::state_set_data(const uint8_t * src, size_t size) { size_t llama_context_base::state_seq_get_size(llama_seq_id seq_id) { llama_io_write_dummy io; try { - return state_seq_get_data(io, 
seq_id); + return state_seq_write_data(io, seq_id); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); return 0; @@ -2074,7 +2070,7 @@ size_t llama_context_base::state_seq_get_size(llama_seq_id seq_id) { size_t llama_context_base::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { llama_io_write_buffer io(dst, size); try { - return state_seq_get_data(io, seq_id); + return state_seq_write_data(io, seq_id); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); return 0; @@ -2084,7 +2080,7 @@ size_t llama_context_base::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst size_t llama_context_base::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { llama_io_read_buffer io(src, size); try { - return state_seq_set_data(io, seq_id); + return state_seq_read_data(io, seq_id); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); return 0; @@ -2123,7 +2119,7 @@ bool llama_context_base::state_load_file(const char * filepath, llama_token * to const size_t n_state_size_cur = file.size() - file.tell(); llama_io_read_file io( &file); - const size_t n_read = state_set_data(io); + const size_t n_read = state_read_data(io); if (n_read != n_state_size_cur) { LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read); @@ -2146,7 +2142,7 @@ bool llama_context_base::state_save_file(const char * filepath, const llama_toke // save the context state using stream saving llama_io_write_file io(&file); - state_get_data(io); + state_write_data(io); return true; } @@ -2182,7 +2178,7 @@ size_t llama_context_base::state_seq_load_file(llama_seq_id seq_id, const char * { const size_t state_size = file.size() - file.tell(); llama_io_read_file io(&file); - const size_t nread = state_seq_set_data(io, seq_id); + const size_t nread = state_seq_read_data(io, seq_id); if (!nread) { LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); return 0; @@ -2206,7 +2202,7 @@ size_t llama_context_base::state_seq_save_file(llama_seq_id seq_id, const char * // save the context state using stream saving llama_io_write_file io(&file); - state_seq_get_data(io, seq_id); + state_seq_write_data(io, seq_id); const size_t res = file.tell(); GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); @@ -2214,7 +2210,7 @@ size_t llama_context_base::state_seq_save_file(llama_seq_id seq_id, const char * return res; } -size_t llama_context_base::state_get_data(llama_io_write_i & io) { +size_t llama_context_base::state_write_data(llama_io_write_i & io) { LLAMA_LOG_DEBUG("%s: writing state\n", __func__); // write model info @@ -2287,7 +2283,7 @@ size_t llama_context_base::state_get_data(llama_io_write_i & io) { return io.n_bytes(); } -size_t llama_context_base::state_set_data(llama_io_read_i & io) { +size_t llama_context_base::state_read_data(llama_io_read_i & io) { LLAMA_LOG_DEBUG("%s: reading state\n", __func__); // read model info @@ -2368,13 +2364,13 @@ size_t llama_context_base::state_set_data(llama_io_read_i & io) { return io.n_bytes(); } -size_t llama_context_base::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) { +size_t llama_context_base::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); return io.n_bytes(); } -size_t 
llama_context_base::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) { +size_t llama_context_base::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); return io.n_bytes(); @@ -2400,9 +2396,6 @@ llama_context_kv_self::llama_context_kv_self( LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - // build worst-case graph for encoder if a model contains encoder - is_encoding = llama_model_has_encoder(&model); // TODO: model.has_encoder() - uint32_t kv_size = cparams.n_ctx; ggml_type type_k = params.type_k; ggml_type type_v = params.type_v; @@ -2537,8 +2530,6 @@ void llama_context_kv_self::kv_self_update() { } int llama_context_kv_self::encode(llama_batch & inp_batch) { - is_encoding = true; - if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; @@ -2589,7 +2580,6 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) { output_ids[i] = i; } - inp_embd_enc = NULL; n_outputs = n_tokens; //batch_manager->prepare(ubatch); @@ -2624,65 +2614,48 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); - if (llama_model_has_decoder(&model)) { - embd_enc.resize(n_tokens*n_embd); - float * embd_out = embd_enc.data(); + GGML_ASSERT(embd != nullptr); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); - GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd; - // remember the sequence ids used during the encoding - needed for cross attention later - seq_ids_enc.resize(n_tokens); - for (int32_t i = 0; i < n_tokens; i++) { - for (int s = 0; s < ubatch.n_seq_id[i]; s++) { - llama_seq_id seq_id = ubatch.seq_id[i][s]; - seq_ids_enc[i].insert(seq_id); - } - } - } else { - GGML_ASSERT(embd != nullptr); + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings + auto & embd_seq_out = embd_seq; + embd_seq_out.clear(); - switch (cparams.pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // extract token embeddings - GGML_ASSERT(embd != nullptr); - float * embd_out = embd; + GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits - GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); - } break; - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_CLS: - case LLAMA_POOLING_TYPE_LAST: - { - // extract sequence embeddings - auto & embd_seq_out = embd_seq; - embd_seq_out.clear(); - - GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits - - for (int32_t i = 0; i < n_tokens; i++) { - const llama_seq_id seq_id = ubatch.seq_id[i][0]; - if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { - continue; - } - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); + for (int32_t i = 0; i < n_tokens; i++) { + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + if (embd_seq_out.find(seq_id) != 
embd_seq_out.end()) { + continue; } - } break; - case LLAMA_POOLING_TYPE_RANK: - { - // TODO: this likely should be the same logic as in llama_decoder_internal, but better to - // wait for an encoder model that requires this pooling type in order to test it - // https://github.com/ggerganov/llama.cpp/pull/9510 - GGML_ABORT("RANK pooling not implemented yet"); + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); } - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ABORT("unknown pooling type"); - } - } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // TODO: this likely should be the same logic as in llama_decoder_internal, but better to + // wait for an encoder model that requires this pooling type in order to test it + // https://github.com/ggerganov/llama.cpp/pull/9510 + GGML_ABORT("RANK pooling not implemented yet"); + } + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } } } @@ -2694,8 +2667,6 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) { } int llama_context_kv_self::decode(llama_batch & inp_batch) { - is_encoding = false; - if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; @@ -3039,7 +3010,7 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { if (inp.self_kq_mask || inp.self_kq_mask_swa) { // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache. - if (cparams.causal_attn && !is_encoding) { + if (cparams.causal_attn) { const int64_t n_kv = kv_self.n; const int64_t n_tokens = ubatch.n_tokens; const int64_t n_seq_tokens = ubatch.n_seq_tokens; @@ -3116,7 +3087,7 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { const int64_t n_seq_tokens = ubatch.n_seq_tokens; const int64_t n_seqs = ubatch.n_seqs; // when using kv cache, the mask needs to match the kv cache size - const int64_t n_stride = hparams.causal_attn && !is_encoding ? 
kv_self.n : n_tokens; + const int64_t n_stride = n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(inp.self_kq_mask->buffer)); @@ -3175,50 +3146,9 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) { } } } - - if (!is_encoding && inp_embd_enc) { - assert(inp_embd_enc->type == GGML_TYPE_F32); - assert((size_t) ggml_nelements(inp_embd_enc) == embd_enc.size()); - - ggml_backend_tensor_set(inp_embd_enc, embd_enc.data(), 0, ggml_nbytes(inp_embd_enc)); - } - - if (!is_encoding && inp_kq_mask_cross) { - const int64_t n_output_enc = embd_enc.size() / hparams.n_embd; - const int64_t n_tokens = ubatch.n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(inp_kq_mask_cross->buffer)); - GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing - - float * data = (float *) inp_kq_mask_cross->data; - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_output_enc; ++i) { - float f = -INFINITY; - for (int s = 0; s < ubatch.n_seq_id[j]; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[j][s]; - if (seq_ids_enc[i].find(seq_id) != seq_ids_enc[i].end()) { - f = 0.0f; - } - } - data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f; - } - } - - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_output_enc; ++j) { - data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY; - } - } - } - } } ggml_cgraph * llama_context_kv_self::graph_init() { - inp_embd_enc = nullptr; - inp_kq_mask_cross = nullptr; - inp = {}; return llama_context_base::graph_init(); @@ -3441,7 +3371,7 @@ void llama_context_kv_self::build_kv_self_defrag( // - x2 for keys and values //const uint32_t max_moves = max_nodes()/(6*n_layer); // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 - const uint32_t max_moves = (max_nodes() - 2*n_layer)/(6*n_layer); + const uint32_t max_moves = (graph_max_nodes() - 2*n_layer)/(6*n_layer); // determine which KV cells to move where // @@ -3689,39 +3619,10 @@ void llama_context_kv_self::build_kv_self_defrag( #endif } -ggml_tensor * llama_context_kv_self::build_inp_embd_enc( - ggml_context * ctx0) { - const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd; - - // TODO: not sure if this is correct - const int32_t n_outputs_enc = embd_enc.size() / n_embd; - - inp_embd_enc = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc); - ggml_set_input(inp_embd_enc); - - return inp_embd_enc; -} - -ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross( - ggml_context * ctx0, - int32_t n_tokens) { - const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd; - - // TODO: not sure if this is correct - const int32_t n_outputs_enc = embd_enc.size() / n_embd; - - inp_kq_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - ggml_set_input(inp_kq_mask_cross); - - return inp_kq_mask_cross; -} - // state save/load -size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) { - llama_context_base::state_get_data(io); +size_t llama_context_kv_self::state_write_data(llama_io_write_i & io) { + llama_context_base::state_write_data(io); LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); kv_self.state_write(io); @@ -3729,8 +3630,8 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) { return io.n_bytes(); } -size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) { - llama_context_base::state_set_data(io); 
+size_t llama_context_kv_self::state_read_data(llama_io_read_i & io) { + llama_context_base::state_read_data(io); LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); kv_self.state_read(io); @@ -3738,16 +3639,16 @@ size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) { return io.n_bytes(); } -size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) { - llama_context_base::state_seq_get_data(io, seq_id); +size_t llama_context_kv_self::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { + llama_context_base::state_seq_write_data(io, seq_id); kv_self.state_write(io, seq_id); return io.n_bytes(); } -size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) { - llama_context_base::state_seq_set_data(io, seq_id); +size_t llama_context_kv_self::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { + llama_context_base::state_seq_read_data(io, seq_id); kv_self.state_read(io, seq_id); @@ -4603,54 +4504,568 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix( // state save/load -size_t llama_context_recurrent::state_get_data(llama_io_write_i & io) { - llama_context_base::state_get_data(io); +size_t llama_context_recurrent::state_write_data(llama_io_write_i & io) { + llama_context_base::state_write_data(io); kv_self.state_write(io); return io.n_bytes(); } -size_t llama_context_recurrent::state_set_data(llama_io_read_i & io) { - llama_context_base::state_set_data(io); +size_t llama_context_recurrent::state_read_data(llama_io_read_i & io) { + llama_context_base::state_read_data(io); kv_self.state_read(io); return io.n_bytes(); } -size_t llama_context_recurrent::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) { - llama_context_base::state_seq_get_data(io, seq_id); +size_t llama_context_recurrent::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { + llama_context_base::state_seq_write_data(io, seq_id); kv_self.state_write(io, seq_id); return io.n_bytes(); } -size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) { - llama_context_base::state_seq_set_data(io, seq_id); +size_t llama_context_recurrent::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { + llama_context_base::state_seq_read_data(io, seq_id); kv_self.state_read(io, seq_id); return io.n_bytes(); } +// +// llama_context_enc +// + +int llama_context_enc::encode(llama_batch & inp_batch) { + if (inp_batch.n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); + return -1; + } + + // temporary allocate memory for the input batch if needed + llama_batch_allocr batch_allocr(inp_batch, 0); + + const llama_batch & batch = batch_allocr.batch; + + const int32_t n_tokens = batch.n_tokens; + + const auto & hparams = model.hparams; + + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + + if (batch.token) { + for (int32_t i = 0; i < n_tokens; ++i) { + if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { + LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); + return -1; + } + } + } + + // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot + GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens"); + + if (t_compute_start_us == 0) { + t_compute_start_us = ggml_time_us(); + } + + n_queued_tokens += n_tokens; + + const int64_t n_embd = hparams.n_embd; + + 
sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + + const llama_ubatch ubatch = sbatch.split_simple(n_tokens); + + // reserve output buffer + if (output_reserve(n_tokens) < n_tokens) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); + return -2; + }; + + for (int32_t i = 0; i < n_tokens; ++i) { + output_ids[i] = i; + } + + n_outputs = n_tokens; + + ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch); + + ggml_backend_sched_alloc_graph(sched.get(), gf); + + input_set(ubatch); + + const auto compute_status = graph_compute(gf, n_tokens > 1); + switch (compute_status) { + case GGML_STATUS_SUCCESS: + break; + case GGML_STATUS_ABORTED: + return 2; + case GGML_STATUS_ALLOC_FAILED: + return -2; + case GGML_STATUS_FAILED: + default: + return -3; + } + + auto * t_embd = res.t_embd_pooled ? res.t_embd_pooled : res.t_embd; + + // extract embeddings + if (t_embd) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + GGML_ASSERT(embd != nullptr); + + // extract token embeddings + float * embd_out = embd; + + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_tokens*n_embd*sizeof(float)); + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings + auto & embd_seq_out = embd_seq; + embd_seq_out.clear(); + + GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits + + for (int32_t i = 0; i < n_tokens; i++) { + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // TODO: this likely should be the same logic as in llama_decoder_internal, but better to + // wait for an encoder model that requires this pooling type in order to test it + // https://github.com/ggerganov/llama.cpp/pull/9510 + GGML_ABORT("RANK pooling not implemented yet"); + } + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } + } + } + + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. 
+ ggml_backend_sched_reset(sched.get()); + + cross->n_outputs = n_tokens; + cross->embd_enc = embd; + + // remember the sequence ids used during the encoding - needed for cross attention later + cross->seq_ids_enc.resize(n_tokens); + for (int32_t i = 0; i < n_tokens; i++) { + for (int s = 0; s < ubatch.n_seq_id[i]; s++) { + llama_seq_id seq_id = ubatch.seq_id[i][s]; + cross->seq_ids_enc[i].insert(seq_id); + } + } + + return 0; +} + +// +// llama_context_dec +// + +void llama_context_dec::reserve() { + // simulate full KV cache + cross->n_outputs = cparams.n_ubatch; + + LLAMA_LOG_DEBUG("%s: n_outputs = %u\n", __func__, cross->n_outputs); + + llama_context_kv_self::reserve(); +} + +void llama_context_dec::input_set(const llama_ubatch & ubatch) { + // call base functionality + llama_context_kv_self::input_set(ubatch); + + if (inp.cross_embd) { + assert(inp.cross_embd->type == GGML_TYPE_F32); + assert(ggml_nelements(inp.cross_embd) == cross->n_outputs*model.hparams.n_embd); + + ggml_backend_tensor_set(inp.cross_embd, cross->embd_enc, 0, ggml_nbytes(inp.cross_embd)); + } + + if (inp.cross_kq_mask) { + const int64_t n_output_enc = cross->n_outputs; + const int64_t n_tokens = ubatch.n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(inp.cross_kq_mask->buffer)); + GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing + + float * data = (float *) inp.cross_kq_mask->data; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_output_enc; ++i) { + float f = -INFINITY; + for (int s = 0; s < ubatch.n_seq_id[j]; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[j][s]; + if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) { + f = 0.0f; + } + } + data[h*(n_output_enc*n_tokens) + j*n_output_enc + i] = f; + } + } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_output_enc; ++j) { + data[h*(n_output_enc*n_tokens) + i*n_output_enc + j] = -INFINITY; + } + } + } + } +} + +ggml_cgraph * llama_context_dec::graph_init() { + inp = {}; + + return llama_context_kv_self::graph_init(); +} + +ggml_tensor * llama_context_dec::build_inp_cross_embd( + ggml_context * ctx0) { + const auto & hparams = model.hparams; + const int64_t n_embd = hparams.n_embd; + + const int32_t n_outputs_enc = cross->n_outputs; + + inp.cross_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc); + ggml_set_input(inp.cross_embd); + + return inp.cross_embd; +} + +void llama_context_dec::build_attn_inp( + ggml_context * ctx0, + int32_t n_tokens, + bool causal, + bool swa) { + llama_context_kv_self::build_attn_inp(ctx0, n_tokens, causal, swa); + + const int32_t n_outputs_enc = cross->n_outputs; + + inp.cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(inp.cross_kq_mask); + + inp.cross_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp.cross_kq_mask, GGML_TYPE_F16) : inp.cross_kq_mask; +} + +ggml_tensor * llama_context_dec::build_attn_cross( + ggml_context * ctx0, + ggml_cgraph * gf, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + float kq_scale, + int il) { + GGML_UNUSED(il); + + const auto & kq_mask = inp.cross_kq_mask_cnv; + + ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); + //cb(q, "q", il); + + ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); + //cb(k, "k", il); + + ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); + //cb(k, "v", il); + + ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, false, kq_scale); + + return cur; +} + // // llama_context_enc_dec // llama_context_enc_dec::llama_context_enc_dec( const llama_model & model, - llama_context_params params) : - llama_context_enc(model, params, LLAMA_GRAPH_TYPE_ENCODER), - ctx_dec(model, params, LLAMA_GRAPH_TYPE_DECODER) { + llama_context_params params) { LLAMA_LOG_INFO("%s: constructing llama_context_enc_dec\n", __func__); + + ctx_enc = std::make_unique<llama_context_enc>(model, params, LLAMA_GRAPH_TYPE_ENCODER); + ctx_dec = std::make_unique<llama_context_dec>(model, params, LLAMA_GRAPH_TYPE_DECODER); + + ctx_enc->cross = &cross; + ctx_dec->cross = &cross; } llama_context_enc_dec::~llama_context_enc_dec() { LLAMA_LOG_INFO("%s: destructing llama_context_enc_dec\n", __func__); } +void llama_context_enc_dec::init() { + ctx_enc->init(); + ctx_dec->init(); +} + +void llama_context_enc_dec::synchronize() { + ctx_enc->synchronize(); + ctx_dec->synchronize(); +} + +const llama_model & llama_context_enc_dec::get_model() const { + return ctx_enc->get_model(); +} + +uint32_t llama_context_enc_dec::n_ctx() const { + return ctx_dec->n_ctx(); +} + +uint32_t llama_context_enc_dec::n_ctx_per_seq() const { + return ctx_dec->n_ctx_per_seq(); +} + +uint32_t llama_context_enc_dec::n_batch() const { + return ctx_dec->n_batch(); +} + +uint32_t llama_context_enc_dec::n_ubatch() const { + return ctx_dec->n_ubatch(); +} + +uint32_t llama_context_enc_dec::n_seq_max() const { + return ctx_dec->n_seq_max(); +} + +uint32_t llama_context_enc_dec::n_threads() const { + return ctx_dec->n_threads(); +} + +uint32_t llama_context_enc_dec::n_threads_batch() const { + return ctx_dec->n_threads_batch(); +} + +llama_kv_cache * llama_context_enc_dec::get_kv_self() { + return ctx_dec->get_kv_self(); +} + +const llama_kv_cache * llama_context_enc_dec::get_kv_self() const { + return ctx_dec->get_kv_self(); +} + +void llama_context_enc_dec::kv_self_update() { + ctx_dec->kv_self_update(); +} + +enum llama_pooling_type llama_context_enc_dec::pooling_type() const { + return ctx_enc->pooling_type(); +} + +float * llama_context_enc_dec::get_logits() { + return ctx_dec->get_logits(); +} + +float * llama_context_enc_dec::get_logits_ith(int32_t i) { + return ctx_dec->get_logits_ith(i); +} + +float * llama_context_enc_dec::get_embeddings() { + return ctx_enc->get_embeddings(); +} + +float * llama_context_enc_dec::get_embeddings_ith(int32_t i) { + return ctx_enc->get_embeddings_ith(i); +} + +float * llama_context_enc_dec::get_embeddings_seq(llama_seq_id seq_id) { + return ctx_enc->get_embeddings_seq(seq_id); +} + +void llama_context_enc_dec::attach_threadpool( + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch) { + // TODO: attach to both - not sure if this is correct + ctx_enc->attach_threadpool(threadpool, threadpool_batch); + ctx_dec->attach_threadpool(threadpool, threadpool_batch); +} + +void llama_context_enc_dec::detach_threadpool() { +
ctx_enc->detach_threadpool(); + ctx_dec->detach_threadpool(); +} + +void llama_context_enc_dec::set_n_threads(int32_t n_threads, int32_t n_threads_batch) { + ctx_enc->set_n_threads(n_threads, n_threads_batch); + ctx_dec->set_n_threads(n_threads, n_threads_batch); +} + +void llama_context_enc_dec::set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) { + ctx_enc->set_abort_callback(abort_callback, abort_callback_data); + ctx_dec->set_abort_callback(abort_callback, abort_callback_data); +} + +void llama_context_enc_dec::set_embeddings(bool value) { + GGML_UNUSED(value); + LLAMA_LOG_WARN("%s: set_embeddings() not supported for llama_context_enc_dec\n", __func__); +} + +void llama_context_enc_dec::set_causal_attn(bool value) { + GGML_UNUSED(value); + LLAMA_LOG_WARN("%s: set_causal_attn() not supported for llama_context_enc_dec\n", __func__); +} + +void llama_context_enc_dec::set_adapter_lora( + llama_adapter_lora * adapter, + float scale) { + ctx_dec->set_adapter_lora(adapter, scale); +} + +bool llama_context_enc_dec::rm_adapter_lora( + llama_adapter_lora * adapter) { + return ctx_dec->rm_adapter_lora(adapter); +} + +void llama_context_enc_dec::clear_adapter_lora() { + ctx_dec->clear_adapter_lora(); +} + +bool llama_context_enc_dec::apply_adapter_cvec( + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + return ctx_dec->apply_adapter_cvec(data, len, n_embd, il_start, il_end); +} + +int llama_context_enc_dec::encode(llama_batch & inp_batch) { + return ctx_enc->encode(inp_batch); +} + +int llama_context_enc_dec::decode(llama_batch & inp_batch) { + return ctx_dec->decode(inp_batch); +} + +// +// perf +// + +llama_perf_context_data llama_context_enc_dec::perf_get_data() const { + return ctx_dec->perf_get_data(); +} + +void llama_context_enc_dec::perf_reset() { + ctx_enc->perf_reset(); + ctx_dec->perf_reset(); +} + +// +// state save/load +// + +size_t llama_context_enc_dec::state_get_size() { + GGML_ABORT("TODO: implement"); +} + +size_t llama_context_enc_dec::state_get_data( uint8_t * dst, size_t size) { + GGML_UNUSED(dst); + GGML_UNUSED(size); + GGML_ABORT("TODO: implement"); +} + +size_t llama_context_enc_dec::state_set_data(const uint8_t * src, size_t size) { + GGML_UNUSED(src); + GGML_UNUSED(size); + GGML_ABORT("TODO: implement"); +} + +size_t llama_context_enc_dec::state_seq_get_size(llama_seq_id seq_id) { + GGML_UNUSED(seq_id); + GGML_ABORT("TODO: implement"); +} + +size_t llama_context_enc_dec::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { + GGML_UNUSED(seq_id); + GGML_UNUSED(dst); + GGML_UNUSED(size); + GGML_ABORT("TODO: implement"); +} + +size_t llama_context_enc_dec::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { + GGML_UNUSED(seq_id); + GGML_UNUSED(src); + GGML_UNUSED(size); + GGML_ABORT("TODO: implement"); +} + +bool llama_context_enc_dec::state_load_file( + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out) { + GGML_UNUSED(filepath); + GGML_UNUSED(tokens_out); + GGML_UNUSED(n_token_capacity); + GGML_UNUSED(n_token_count_out); + GGML_ABORT("TODO: implement"); +} + +bool llama_context_enc_dec::state_save_file( + const char * filepath, + const llama_token * tokens, + size_t n_token_count) { + GGML_UNUSED(filepath); + GGML_UNUSED(tokens); + GGML_UNUSED(n_token_count); + GGML_ABORT("TODO: implement"); +} + +size_t llama_context_enc_dec::state_seq_load_file( + llama_seq_id seq_id, + const char * filepath, + 
llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out) { + GGML_UNUSED(seq_id); + GGML_UNUSED(filepath); + GGML_UNUSED(tokens_out); + GGML_UNUSED(n_token_capacity); + GGML_UNUSED(n_token_count_out); + GGML_ABORT("TODO: implement"); +} + +size_t llama_context_enc_dec::state_seq_save_file( + llama_seq_id seq_id, + const char * filepath, + const llama_token * tokens, + size_t n_token_count) { + GGML_UNUSED(seq_id); + GGML_UNUSED(filepath); + GGML_UNUSED(tokens); + GGML_UNUSED(n_token_count); + GGML_ABORT("TODO: implement"); +} + // // interface implementation // diff --git a/src/llama-context.h b/src/llama-context.h index d647a426cd..3165865a73 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -30,8 +30,7 @@ public: virtual void synchronize() = 0; - virtual const llama_model & get_model() const = 0; - virtual const llama_cparams & get_cparams() const = 0; + virtual const llama_model & get_model() const = 0; virtual uint32_t n_ctx() const = 0; virtual uint32_t n_ctx_per_seq() const = 0; @@ -42,8 +41,6 @@ public: virtual uint32_t n_threads() const = 0; virtual uint32_t n_threads_batch() const = 0; - virtual int32_t max_nodes() const = 0; - // self-attention: // if the context does not have a KV cache, return nullptr @@ -62,8 +59,6 @@ public: virtual float * get_embeddings_ith(int32_t i) = 0; virtual float * get_embeddings_seq(llama_seq_id seq_id) = 0; - virtual int64_t n_pos_per_token() const = 0; // vision - virtual void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) = 0; @@ -190,8 +185,7 @@ protected: virtual void reserve(); public: - const llama_model & get_model() const override; - const llama_cparams & get_cparams() const override; + const llama_model & get_model() const override; uint32_t n_ctx() const override; uint32_t n_ctx_per_seq() const override; @@ -202,15 +196,9 @@ public: uint32_t n_threads() const override; uint32_t n_threads_batch() const override; - int32_t max_nodes() const override; - - // self-attention: - - // if the context does not have a KV cache, return nullptr llama_kv_cache * get_kv_self() override; const llama_kv_cache * get_kv_self() const override; - // if the context does not have a KV cache, noop void kv_self_update() override; enum llama_pooling_type pooling_type() const override; @@ -222,8 +210,6 @@ public: float * get_embeddings_ith(int32_t i) override; float * get_embeddings_seq(llama_seq_id seq_id) override; - int64_t n_pos_per_token() const override; // vision - void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) override; @@ -261,6 +247,8 @@ protected: // input // + virtual int64_t n_pos_per_token() const; // vision + // when the compute graph is built, it creates the input tensors that it needs // the contents of the input tensors are set by the input_set() function @@ -299,6 +287,8 @@ protected: // graph // + virtual int32_t graph_max_nodes() const; + // zero-out inputs and create the ctx_compute for the compute graph virtual ggml_cgraph * graph_init(); @@ -477,11 +467,11 @@ public: size_t n_token_count) override; protected: - virtual size_t state_get_data(llama_io_write_i & io); - virtual size_t state_set_data(llama_io_read_i & io); + virtual size_t state_write_data(llama_io_write_i & io); + virtual size_t state_read_data (llama_io_read_i & io); - virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id); - virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id); + virtual size_t 
state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id); + virtual size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id); // // members @@ -625,39 +615,15 @@ protected: ggml_context * ctx0, ggml_cgraph * gf) override; - // ======================================================= - // === encoder-decoder === - // - // TODO: this is temporary here, it will be moved - // - - // whether we are computing encoder output or decoder output - bool is_encoding = false; - - // output of the encoder part of the encoder-decoder models - std::vector<float> embd_enc; - std::vector<std::set<llama_seq_id>> seq_ids_enc; - - struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] - struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch] - - ggml_tensor * build_inp_embd_enc( - ggml_context * ctx0) override; - - ggml_tensor * build_inp_kq_mask_cross( - ggml_context * ctx0, - int32_t n_tokens) override; - // ====================================================== - // // state save/load // - size_t state_get_data(llama_io_write_i & io) override; - size_t state_set_data(llama_io_read_i & io) override; + size_t state_write_data(llama_io_write_i & io) override; + size_t state_read_data (llama_io_read_i & io) override; - size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override; - size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override; + size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) override; + size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id) override; private: // @@ -767,11 +733,11 @@ protected: // state save/load // - size_t state_get_data(llama_io_write_i & io) override; - size_t state_set_data(llama_io_read_i & io) override; + size_t state_write_data(llama_io_write_i & io) override; + size_t state_read_data (llama_io_read_i & io) override; - size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override; - size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override; + size_t state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) override; + size_t state_seq_read_data (llama_io_read_i & io, llama_seq_id seq_id) override; private: // @@ -782,21 +748,206 @@ private: llama_kv_cache_recurrent kv_self; }; +// TODO: tmp - need something better +struct llama_cross { + int32_t n_outputs; + float * embd_enc; + + std::vector<std::set<llama_seq_id>> seq_ids_enc; +}; + class llama_context_enc : public llama_context_base { public: using llama_context_base::llama_context_base; + + int encode(llama_batch & inp_batch) override; + + llama_cross * cross = nullptr; }; -class llama_context_enc_dec : public llama_context_enc { +class llama_context_dec : public llama_context_kv_self { +public: + using llama_context_kv_self::llama_context_kv_self; + +protected: + void reserve() override; + + // + // input + // + + void input_set(const llama_ubatch & ubatch) override; + +private: + struct { + ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc] + ggml_tensor * cross_kq_mask; // F32 [n_outputs_enc, n_batch] + ggml_tensor * cross_kq_mask_cnv; // F32 [n_outputs_enc, n_batch] + } inp; + +protected: + // + // graph + // + + ggml_cgraph * graph_init() override; + + ggml_tensor * build_inp_cross_embd( + ggml_context * ctx0) override; + + void build_attn_inp( + ggml_context * ctx0, + int32_t n_tokens, + bool causal, + bool swa) override; + + ggml_tensor * build_attn_cross( + ggml_context * ctx0, + ggml_cgraph * gf, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, +
float kq_scale, + int il) override; + +public: + llama_cross * cross = nullptr; +}; + +class llama_context_enc_dec : public llama_context_i { public: llama_context_enc_dec( const llama_model & model, llama_context_params params); - virtual ~llama_context_enc_dec(); + ~llama_context_enc_dec(); + + void init() override; + + void synchronize() override; + + const llama_model & get_model() const override; + + // TODO: the default implementation of these getters calls the corresponding getter of the enc or dec context + // in the future, the public API in llama.h should allow to get references to the context that the user wants + // this will allow to specify the desired context explicitly + // for example: + // + // // this can be an enc-dec context + // llama_context_t ctx = llama_init_from_model(...); + // + // ... + // + // llama_context_t ctx_enc = llama_get_ctx_enc(ctx); + // llama_set_embeddings(ctx_enc, true); + // + // llama_context_t ctx_dec = llama_get_ctx_dec(ctx); + // llama_set_causal_attn(ctx_dec, true); + // + uint32_t n_ctx() const override; + uint32_t n_ctx_per_seq() const override; + uint32_t n_batch() const override; + uint32_t n_ubatch() const override; + uint32_t n_seq_max() const override; + + uint32_t n_threads() const override; + uint32_t n_threads_batch() const override; + + llama_kv_cache * get_kv_self() override; + const llama_kv_cache * get_kv_self() const override; + + void kv_self_update() override; + + enum llama_pooling_type pooling_type() const override; + + float * get_logits() override; + float * get_logits_ith(int32_t i) override; + + float * get_embeddings() override; + float * get_embeddings_ith(int32_t i) override; + float * get_embeddings_seq(llama_seq_id seq_id) override; + + void attach_threadpool( + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch) override; + + void detach_threadpool() override; + + void set_n_threads(int32_t n_threads, int32_t n_threads_batch) override; + + void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data) override; + + void set_embeddings (bool value) override; + void set_causal_attn(bool value) override; + + void set_adapter_lora( + llama_adapter_lora * adapter, + float scale) override; + + bool rm_adapter_lora( + llama_adapter_lora * adapter) override; + + void clear_adapter_lora() override; + + bool apply_adapter_cvec( + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) override; + + int encode(llama_batch & inp_batch) override; + int decode(llama_batch & inp_batch) override; + + // + // perf + // + + llama_perf_context_data perf_get_data() const override; + void perf_reset() override; + + // + // state save/load + // + + size_t state_get_size() override; + size_t state_get_data( uint8_t * dst, size_t size) override; + size_t state_set_data(const uint8_t * src, size_t size) override; + + size_t state_seq_get_size(llama_seq_id seq_id) override; + size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override; + size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override; + + bool state_load_file( + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out) override; + + bool state_save_file( + const char * filepath, + const llama_token * tokens, + size_t n_token_count) override; + + size_t state_seq_load_file( + llama_seq_id seq_id, + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * 
n_token_count_out) override; + + size_t state_seq_save_file( + llama_seq_id seq_id, + const char * filepath, + const llama_token * tokens, + size_t n_token_count) override; private: - llama_context_kv_self ctx_dec; + std::unique_ptr<llama_context_enc> ctx_enc; + std::unique_ptr<llama_context_dec> ctx_dec; + + llama_cross cross; }; // For internal test use diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 99eb326205..1e336e844a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -26,7 +26,29 @@ ggml_tensor * llama_graph_i::build_attn( return nullptr; } -ggml_tensor * llama_graph_i::build_inp_embd_enc( +ggml_tensor * llama_graph_i::build_attn_cross( + ggml_context * ctx0, + ggml_cgraph * gf, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + float kq_scale, + int il) { + GGML_UNUSED(ctx0); + GGML_UNUSED(gf); + GGML_UNUSED(q_cur); + GGML_UNUSED(k_cur); + GGML_UNUSED(v_cur); + GGML_UNUSED(kq_b); + GGML_UNUSED(kq_scale); + GGML_UNUSED(il); + + LLAMA_LOG_ERROR("%s: not implemented\n", __func__); + return nullptr; +} + +ggml_tensor * llama_graph_i::build_inp_cross_embd( ggml_context * ctx0) { GGML_UNUSED(ctx0); @@ -34,7 +56,7 @@ ggml_tensor * llama_graph_i::build_inp_embd_enc( return nullptr; } -ggml_tensor * llama_graph_i::build_inp_kq_mask_cross( +ggml_tensor * llama_graph_i::build_inp_cross_kq_mask( ggml_context * ctx0, int32_t n_tokens) { GGML_UNUSED(ctx0); diff --git a/src/llama-graph.h b/src/llama-graph.h index c84c254934..28e8a56306 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -114,10 +114,20 @@ public: float kq_scale, int il); - virtual ggml_tensor * build_inp_embd_enc( + virtual ggml_tensor * build_attn_cross( + ggml_context * ctx0, + ggml_cgraph * gf, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + float kq_scale, + int il); + + virtual ggml_tensor * build_inp_cross_embd( ggml_context * ctx0); - virtual ggml_tensor * build_inp_kq_mask_cross( + virtual ggml_tensor * build_inp_cross_kq_mask( ggml_context * ctx0, int32_t n_tokens); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e8057f4687..38e8c2812f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -3964,16 +3964,16 @@ struct llm_build_context { } // TODO: tmp - struct ggml_tensor * build_inp_embd_enc() { - ggml_tensor * cur = lgf->build_inp_embd_enc(ctx0); + struct ggml_tensor * build_inp_cross_embd() { + ggml_tensor * cur = lgf->build_inp_cross_embd(ctx0); cb(cur, "embd_enc", -1); return cur; } // TODO: tmp - struct ggml_tensor * build_inp_kq_mask_cross() { - ggml_tensor * cur = lgf->build_inp_kq_mask_cross(ctx0, n_tokens); + struct ggml_tensor * build_inp_cross_kq_mask() { + ggml_tensor * cur = lgf->build_inp_cross_kq_mask(ctx0, n_tokens); cb(cur, "KQ_mask_cross", -1); return cur; @@ -4294,6 +4294,42 @@ struct llm_build_context { return cur; } + struct ggml_tensor * build_attn_cross( + struct ggml_cgraph * gf, + struct ggml_tensor * wo, + struct ggml_tensor * wo_b, + struct ggml_tensor * q_cur, + struct ggml_tensor * k_cur, + struct ggml_tensor * v_cur, + int32_t n_tokens, // TODO: remove + float kq_scale, + int il) { + GGML_UNUSED(n_tokens); + + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + ggml_tensor * cur = lgf->build_attn_cross(ctx0, gf, q_cur, k_cur, v_cur, nullptr, kq_scale, il); + cb(cur, "kqv_out", il); +
+ if (wo) { + cur = lgf->build_lora_mm(ctx0, wo, cur); + } + + if (wo_b) { + //cb(cur, "kqv_wo", il); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; + } + struct ggml_tensor * build_attn_with_kq_b( struct ggml_cgraph * gf, struct ggml_tensor * wo, @@ -9762,209 +9798,173 @@ struct llm_build_context { ggml_build_forward_expand(gf, cur); } - //void build_t5_dec(ggml_cgraph * gf) { - // const int64_t n_embd_head = hparams.n_embd_head_v; - // const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + void build_t5_dec(ggml_cgraph * gf) { + const int64_t n_embd_head = hparams.n_embd_head_v; + //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - // GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - // struct ggml_tensor * cur; - // struct ggml_tensor * inpL; + struct ggml_tensor * cur; + struct ggml_tensor * inpL; - // inpL = build_inp_embd(model.tok_embd); + inpL = build_inp_embd(model.tok_embd); - // GGML_ASSERT(!lctx.is_encoding); - // GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first"); + struct ggml_tensor * embd_enc = build_inp_cross_embd(); + struct ggml_tensor * pos_bucket_dec = build_pos_bucket(); - // struct ggml_tensor * embd_enc = build_inp_embd_enc(); - // struct ggml_tensor * pos_bucket_dec = build_pos_bucket(true); + const int64_t n_outputs_enc = embd_enc->ne[1]; - // struct ggml_tensor * KQ_mask_dec = build_inp_kq_mask(); - // struct ggml_tensor * KQ_mask_cross = build_inp_kq_mask_cross(); + lgf->build_attn_inp(ctx0, n_tokens, true, false); - // for (int il = 0; il < n_layer; ++il) { - // struct ggml_tensor * inpSA = inpL; + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; - // // norm - // cur = build_norm(inpL, - // model.layers[il].attn_norm, NULL, - // LLM_NORM_RMS, il); - // cb(cur, "attn_norm", il); + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); - // // self-attention - // { - // struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - // cb(Qcur, "Qcur", il); + // self-attention + { + struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); - // struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - // cb(Kcur, "Kcur", il); + struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); - // struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - // cb(Vcur, "Vcur", il); + struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); - // build_kv_store(gf, Kcur, Vcur, il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - // struct ggml_tensor * k = - // ggml_view_3d(ctx0, kv_self.k_l[il], - // n_embd_head_k, n_kv, n_head_kv, - // ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), - // ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), - // 0); - // cb(k, "k", il); + struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? 
model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; + struct ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); - // struct ggml_tensor * v = - // ggml_view_3d(ctx0, kv_self.v_l[il], - // n_kv, n_embd_head_v, n_head_kv, - // ggml_element_size(kv_self.v_l[il])*n_ctx, - // ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v, - // 0); - // cb(v, "v", il); + cur = build_attn_with_kq_b(gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, kq_b, n_tokens, 1.0f, il); + cb(cur, "kqv_out", il); + } - // Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "cross_inp", il); - // struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + struct ggml_tensor * inpCA = cur; - // struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - // cb(kq, "kq", il); + // norm + cur = build_norm(cur, + model.layers[il].attn_norm_cross, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm_cross", il); - // struct ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; - // struct ggml_tensor * pos_bias = build_pos_bias(pos_bucket_dec, attn_rel_b); - // struct ggml_tensor * kq_b = ggml_add(ctx0, kq, pos_bias); - // cb(kq_b, "kq_b", il); + // cross-attention + { + struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_cross, cur); + cb(Qcur, "Qcur", il); - // kq = ggml_soft_max_ext(ctx0, kq_b, KQ_mask_dec, 1.0f, hparams.f_max_alibi_bias); - // cb(kq, "kq_soft_max_ext", il); + struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_cross, embd_enc); + cb(Kcur, "Kcur", il); - // struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); - // cb(kqv, "kqv", il); + struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_cross, embd_enc); + cb(Vcur, "Vcur", il); - // struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - // cb(kqv_merged, "kqv_merged", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc); - // cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); - // cb(cur, "kqv_merged_cont", il); + cur = build_attn_cross(gf, + model.layers[il].wo_cross, nullptr, + Qcur, Kcur, Vcur, n_tokens, 1.0f, il); + cb(cur, "kqv_out", il); - // ggml_build_forward_expand(gf, cur); + //struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + //struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); - // cur = build_lora_mm(model.layers[il].wo, cur); - // cb(cur, "kqv_out", il); - // } + //struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + //cb(kq, "kq", il); - // cur = ggml_add(ctx0, cur, inpSA); - // cb(cur, "cross_inp", il); + //kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias); + //cb(kq, "kq_soft_max_ext", il); - // struct ggml_tensor * inpCA = cur; + //struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc))); + //cb(v, "v", il); - // // norm - // cur = build_norm(cur, - // model.layers[il].attn_norm_cross, NULL, - // LLM_NORM_RMS, il); - // cb(cur, "attn_norm_cross", il); + //struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq); + //cb(kqv, "kqv", il); - // // cross-attention - // { - // struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_cross, cur); - // cb(Qcur, "Qcur", 
il); + //struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + //cb(kqv_merged, "kqv_merged", il); - // struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_cross, embd_enc); - // cb(Kcur, "Kcur", il); + //cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + //cb(cur, "kqv_merged_cont", il); - // struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_cross, embd_enc); - // cb(Vcur, "Vcur", il); + //ggml_build_forward_expand(gf, cur); - // Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - // Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc); + //cur = build_lora_mm(model.layers[il].wo_cross, cur); + //cb(cur, "kqv_out", il); + } - // struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - // struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); + } - // struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - // cb(kq, "kq", il); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA); + cb(ffn_inp, "ffn_inp", il); - // kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias); - // cb(kq, "kq_soft_max_ext", il); + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); - // struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc))); - // cb(v, "v", il); + // T5 uses relu, flan-T5 uses gelu-gated + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? 
LLM_FFN_PAR : LLM_FFN_SEQ, + il); + cb(cur, "ffn_out", il); + } - // struct ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq); - // cb(kqv, "kqv", il); + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); - // struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - // cb(kqv_merged, "kqv_merged", il); + cur = lgf->build_cvec(ctx0, cur, il); + cb(cur, "l_out", il); - // cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); - // cb(cur, "kqv_merged_cont", il); + // input for next layer + inpL = cur; + } - // ggml_build_forward_expand(gf, cur); + cur = inpL; + cb(cur, "result_embd", -1); - // cur = build_lora_mm(model.layers[il].wo_cross, cur); - // cb(cur, "kqv_out", il); - // } + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); - // if (il == n_layer - 1) { - // // skip computing output for unused tokens - // struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - // cur = ggml_get_rows(ctx0, cur, inp_out_ids); - // inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - // inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); - // } + cb(cur, "result_norm", -1); + res.t_embd = cur; - // struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA); - // cb(ffn_inp, "ffn_inp", il); + // lm_head + cur = build_lora_mm(model.output, cur); - // // feed-forward network - // { - // cur = build_norm(ffn_inp, - // model.layers[il].ffn_norm, NULL, - // LLM_NORM_RMS, il); - // cb(cur, "ffn_norm", il); + cb(cur, "result_output", -1); + res.t_logits = cur; - // // T5 uses relu, flan-T5 uses gelu-gated - // cur = build_ffn(cur, - // model.layers[il].ffn_up, NULL, NULL, - // model.layers[il].ffn_gate, NULL, NULL, - // model.layers[il].ffn_down, NULL, NULL, - // NULL, - // model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, - // model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, - // il); - // cb(cur, "ffn_out", il); - // } - - // cur = ggml_add(ctx0, cur, ffn_inp); - // cb(cur, "ffn_out", il); - - // ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); - // if (layer_dir != nullptr) { - // cur = ggml_add(ctx0, cur, layer_dir); - // } - // cb(cur, "l_out", il); - - // // input for next layer - // inpL = cur; - // } - - // cur = inpL; - // cb(cur, "result_embd", -1); - - // cur = build_norm(cur, - // model.output_norm, NULL, - // LLM_NORM_RMS, -1); - - // cb(cur, "result_norm", -1); - // res.t_embd = cur; - - // // lm_head - // cur = build_lora_mm(model.output, cur); - - // cb(cur, "result_output", -1); - // res.t_logits = cur; - - // ggml_build_forward_expand(gf, cur); - - // return gf; - //} + ggml_build_forward_expand(gf, cur); + } void build_jais(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -11119,7 +11119,7 @@ llama_graph_result llama_model::build_graph( llm.build_t5_enc(gf); break; case LLAMA_GRAPH_TYPE_DECODER: - //llm.build_t5_dec(gf); + llm.build_t5_dec(gf); break; default: GGML_ABORT("invalid graph type");
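Note for reviewers: the new `llama_cross` struct carries the encoder state that the decoder's cross-attention consumes. `llama_context_enc::encode()` fills `n_outputs`, `embd_enc` and `seq_ids_enc`, and `llama_context_dec::input_set()` turns `seq_ids_enc` into the cross-attention KQ mask: a decoder token may attend to an encoder output only if the two share a sequence id. Below is a minimal standalone sketch of that masking rule, not code from this patch; it uses a plain `std::vector<float>` in place of the ggml tensor, skips the `GGML_KQ_MASK_PAD` padding, and the names `llama_cross_sketch` and `build_cross_kq_mask` are illustrative only.

```cpp
// Standalone illustration of the cross-attention mask built by llama_context_dec::input_set().
// Not part of the patch: plain float buffer instead of a ggml tensor, no GGML_KQ_MASK_PAD padding.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

using llama_seq_id = int32_t; // same typedef as in llama.h

// mirrors the fields of the llama_cross struct added in llama-context.h
struct llama_cross_sketch {
    int32_t n_outputs = 0;                           // number of encoder output positions
    std::vector<std::set<llama_seq_id>> seq_ids_enc; // seq ids seen at each encoder position
};

// mask[j*n_enc + i] == 0.0f if decoder token j may attend to encoder output i, -INFINITY otherwise
static std::vector<float> build_cross_kq_mask(
        const llama_cross_sketch & cross,
        const std::vector<std::vector<llama_seq_id>> & dec_seq_ids) {
    const int64_t n_enc    = cross.n_outputs;
    const int64_t n_tokens = (int64_t) dec_seq_ids.size();

    std::vector<float> mask(n_enc*n_tokens, -INFINITY);

    for (int64_t j = 0; j < n_tokens; ++j) {
        for (int64_t i = 0; i < n_enc; ++i) {
            for (llama_seq_id s : dec_seq_ids[j]) {
                if (cross.seq_ids_enc[i].count(s) > 0) {
                    mask[j*n_enc + i] = 0.0f;
                    break;
                }
            }
        }
    }

    return mask;
}

int main() {
    llama_cross_sketch cross;
    cross.n_outputs   = 3;
    cross.seq_ids_enc = { {0}, {0}, {1} }; // encoder positions 0,1 belong to seq 0, position 2 to seq 1

    // two decoder tokens, one in seq 0 and one in seq 1
    const std::vector<float> mask = build_cross_kq_mask(cross, { {0}, {1} });

    for (size_t k = 0; k < mask.size(); ++k) {
        printf("%6.1f%c", mask[k], (k + 1) % (size_t) cross.n_outputs == 0 ? '\n' : ' ');
    }
    return 0;
}
```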