mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	android : adapt to new API
This commit is contained in:
		@@ -125,7 +125,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
 | 
			
		||||
    ctx_params.n_threads       = n_threads;
 | 
			
		||||
    ctx_params.n_threads_batch = n_threads;
 | 
			
		||||
 | 
			
		||||
    llama_context * context = llama_new_context_with_model(model, ctx_params);
 | 
			
		||||
    llama_context * context = llama_init_from_model(model, ctx_params);
 | 
			
		||||
 | 
			
		||||
    if (!context) {
 | 
			
		||||
        LOGe("llama_new_context_with_model() returned null)");
 | 
			
		||||
@@ -175,7 +175,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 | 
			
		||||
 | 
			
		||||
    const auto context = reinterpret_cast<llama_context *>(context_pointer);
 | 
			
		||||
    const auto model = reinterpret_cast<llama_model *>(model_pointer);
 | 
			
		||||
    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
 | 
			
		||||
    const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
 | 
			
		||||
 | 
			
		||||
    const int n_ctx = llama_n_ctx(context);
 | 
			
		||||
 | 
			
		||||
@@ -186,19 +186,20 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 | 
			
		||||
    for (nri = 0; nri < nr; nri++) {
 | 
			
		||||
        LOGi("Benchmark prompt processing (pp)");
 | 
			
		||||
 | 
			
		||||
        common_batch_clear(*batch);
 | 
			
		||||
        llama_batch_ext_clear(batch);
 | 
			
		||||
 | 
			
		||||
        const int n_tokens = pp;
 | 
			
		||||
        for (i = 0; i < n_tokens; i++) {
 | 
			
		||||
            common_batch_add(*batch, 0, i, { 0 }, false);
 | 
			
		||||
            llama_seq_id seq_id = 0;
 | 
			
		||||
            llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        batch->logits[batch->n_tokens - 1] = true;
 | 
			
		||||
        llama_batch_ext_set_output_last(batch);
 | 
			
		||||
        llama_kv_self_clear(context);
 | 
			
		||||
 | 
			
		||||
        const auto t_pp_start = ggml_time_us();
 | 
			
		||||
        if (llama_decode(context, *batch) != 0) {
 | 
			
		||||
            LOGi("llama_decode() failed during prompt processing");
 | 
			
		||||
        if (llama_decode_ext(context, batch) != 0) {
 | 
			
		||||
            LOGi("llama_decode_ext() failed during prompt processing");
 | 
			
		||||
        }
 | 
			
		||||
        const auto t_pp_end = ggml_time_us();
 | 
			
		||||
 | 
			
		||||
@@ -210,14 +211,15 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 | 
			
		||||
        const auto t_tg_start = ggml_time_us();
 | 
			
		||||
        for (i = 0; i < tg; i++) {
 | 
			
		||||
 | 
			
		||||
            common_batch_clear(*batch);
 | 
			
		||||
            llama_batch_ext_clear(batch);
 | 
			
		||||
            for (j = 0; j < pl; j++) {
 | 
			
		||||
                common_batch_add(*batch, 0, i, { j }, true);
 | 
			
		||||
                llama_seq_id seq_id = j;
 | 
			
		||||
                llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, true);
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            LOGi("llama_decode() text generation: %d", i);
 | 
			
		||||
            if (llama_decode(context, *batch) != 0) {
 | 
			
		||||
                LOGi("llama_decode() failed during text generation");
 | 
			
		||||
            LOGi("llama_decode_ext() text generation: %d", i);
 | 
			
		||||
            if (llama_decode_ext(context, batch) != 0) {
 | 
			
		||||
                LOGi("llama_decode_ext() failed during text generation");
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -272,32 +274,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 | 
			
		||||
extern "C"
 | 
			
		||||
JNIEXPORT jlong JNICALL
 | 
			
		||||
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
 | 
			
		||||
 | 
			
		||||
    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
 | 
			
		||||
 | 
			
		||||
    llama_batch *batch = new llama_batch {
 | 
			
		||||
        0,
 | 
			
		||||
        nullptr,
 | 
			
		||||
        nullptr,
 | 
			
		||||
        nullptr,
 | 
			
		||||
        nullptr,
 | 
			
		||||
        nullptr,
 | 
			
		||||
        nullptr,
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    if (embd) {
 | 
			
		||||
        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
 | 
			
		||||
    } else {
 | 
			
		||||
        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
 | 
			
		||||
    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
 | 
			
		||||
    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
 | 
			
		||||
    for (int i = 0; i < n_tokens; ++i) {
 | 
			
		||||
        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
 | 
			
		||||
    }
 | 
			
		||||
    batch->logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
 | 
			
		||||
    llama_batch_ext * batch = llama_batch_ext_init(n_tokens, n_seq_max);
 | 
			
		||||
 | 
			
		||||
    return reinterpret_cast<jlong>(batch);
 | 
			
		||||
}
 | 
			
		||||
@@ -305,9 +282,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
 | 
			
		||||
extern "C"
 | 
			
		||||
JNIEXPORT void JNICALL
 | 
			
		||||
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
 | 
			
		||||
    //llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
 | 
			
		||||
    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
 | 
			
		||||
    delete batch;
 | 
			
		||||
    llama_batch_ext_free(reinterpret_cast<llama_batch_ext *>(batch_pointer));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
extern "C"
 | 
			
		||||
@@ -355,7 +330,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
 | 
			
		||||
 | 
			
		||||
    const auto text = env->GetStringUTFChars(jtext, 0);
 | 
			
		||||
    const auto context = reinterpret_cast<llama_context *>(context_pointer);
 | 
			
		||||
    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
 | 
			
		||||
    const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
 | 
			
		||||
 | 
			
		||||
    bool parse_special = (format_chat == JNI_TRUE);
 | 
			
		||||
    const auto tokens_list = common_tokenize(context, text, true, parse_special);
 | 
			
		||||
@@ -363,7 +338,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
 | 
			
		||||
    auto n_ctx = llama_n_ctx(context);
 | 
			
		||||
    auto n_kv_req = tokens_list.size() + n_len;
 | 
			
		||||
 | 
			
		||||
    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
 | 
			
		||||
    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", (int) n_len, (int) n_ctx, (int) n_kv_req);
 | 
			
		||||
 | 
			
		||||
    if (n_kv_req > n_ctx) {
 | 
			
		||||
        LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
 | 
			
		||||
@@ -373,23 +348,24 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
 | 
			
		||||
        LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    common_batch_clear(*batch);
 | 
			
		||||
    llama_batch_ext_clear(batch);
 | 
			
		||||
 | 
			
		||||
    // evaluate the initial prompt
 | 
			
		||||
    for (auto i = 0; i < tokens_list.size(); i++) {
 | 
			
		||||
        common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
 | 
			
		||||
        llama_seq_id seq_id = 0;
 | 
			
		||||
        llama_batch_ext_add_text(batch, tokens_list[i], i, &seq_id, 1, false);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // llama_decode will output logits only for the last token of the prompt
 | 
			
		||||
    batch->logits[batch->n_tokens - 1] = true;
 | 
			
		||||
    llama_batch_ext_set_output_last(batch);
 | 
			
		||||
 | 
			
		||||
    if (llama_decode(context, *batch) != 0) {
 | 
			
		||||
        LOGe("llama_decode() failed");
 | 
			
		||||
    if (llama_decode_ext(context, batch) != 0) {
 | 
			
		||||
        LOGe("llama_decode_ext() failed");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    env->ReleaseStringUTFChars(jtext, text);
 | 
			
		||||
 | 
			
		||||
    return batch->n_tokens;
 | 
			
		||||
    return llama_batch_ext_get_n_tokens(batch);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
extern "C"
 | 
			
		||||
@@ -404,7 +380,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
 | 
			
		||||
        jobject intvar_ncur
 | 
			
		||||
) {
 | 
			
		||||
    const auto context = reinterpret_cast<llama_context *>(context_pointer);
 | 
			
		||||
    const auto batch   = reinterpret_cast<llama_batch   *>(batch_pointer);
 | 
			
		||||
    const auto batch   = reinterpret_cast<llama_batch_ext *>(batch_pointer);
 | 
			
		||||
    const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
 | 
			
		||||
    const auto model = llama_get_model(context);
 | 
			
		||||
    const auto vocab = llama_model_get_vocab(model);
 | 
			
		||||
@@ -433,13 +409,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
 | 
			
		||||
        new_token = env->NewStringUTF("");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    common_batch_clear(*batch);
 | 
			
		||||
    common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
 | 
			
		||||
    llama_batch_ext_clear(batch);
 | 
			
		||||
    llama_seq_id seq_id = 0;
 | 
			
		||||
    llama_batch_ext_add_text(batch, new_token_id, n_cur, &seq_id, 1, true);
 | 
			
		||||
 | 
			
		||||
    env->CallVoidMethod(intvar_ncur, la_int_var_inc);
 | 
			
		||||
 | 
			
		||||
    if (llama_decode(context, *batch) != 0) {
 | 
			
		||||
        LOGe("llama_decode() returned null");
 | 
			
		||||
    if (llama_decode_ext(context, batch) != 0) {
 | 
			
		||||
        LOGe("llama_decode_ext() returned null");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return new_token;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user