embedding: assign n_ubatch value, print error on n_batch overflow
Author: Minsoo Cheong

@@ -61,6 +61,8 @@ int main(int argc, char ** argv) {
     }
 
     params.embedding = true;
+    // For BERT models, batch size must be equal to ubatch size
+    params.n_ubatch = params.n_batch;
 
     print_build_info();
 
@@ -114,7 +116,9 @@ int main(int argc, char ** argv) {
     for (const auto & prompt : prompts) {
         auto inp = ::llama_tokenize(ctx, prompt, true, false);
         if (inp.size() > n_batch) {
-            inp.resize(n_batch);
+            fprintf(stderr, "%s: error: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
+                    __func__, (long long int) inp.size(), (long long int) n_batch);
+            return 1;
         }
         inputs.push_back(inp);
     }
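The patch does two things. First, BERT-style (non-causal) embedding models evaluate a whole batch in a single micro-batch, so the micro-batch size must cover the batch size; pinning params.n_ubatch = params.n_batch satisfies that, per the comment in the diff. Second, prompts longer than n_batch were previously truncated silently via inp.resize(n_batch), so the example embedded different text than the user supplied; the patch now reports the overflow and exits. Below is a minimal standalone sketch of the same guard pattern; tokenize_prompt() and the fixed n_batch value are hypothetical stand-ins for ::llama_tokenize() and params.n_batch, not part of the actual commit.

    // Sketch of the "fail loudly instead of truncating" check from the patch.
    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <vector>

    // Hypothetical tokenizer: one dummy token id per whitespace-separated word.
    static std::vector<int32_t> tokenize_prompt(const std::string & prompt) {
        std::vector<int32_t> tokens;
        size_t pos = 0;
        while (pos < prompt.size()) {
            size_t end = prompt.find(' ', pos);
            if (end == std::string::npos) end = prompt.size();
            if (end > pos) tokens.push_back((int32_t) tokens.size());
            pos = end + 1;
        }
        return tokens;
    }

    int main() {
        const size_t n_batch = 4; // stand-in for params.n_batch

        std::vector<std::string> prompts = { "hello world", "one two three four five" };
        std::vector<std::vector<int32_t>> inputs;

        for (const auto & prompt : prompts) {
            auto inp = tokenize_prompt(prompt);
            // A silently shortened prompt would yield an embedding of
            // different text than the user gave, so error out instead.
            if (inp.size() > n_batch) {
                fprintf(stderr, "error: prompt has %zu tokens, batch size is %zu; increase batch size and re-run\n",
                        inp.size(), n_batch);
                return 1;
            }
            inputs.push_back(inp);
        }
        printf("tokenized %zu prompts\n", inputs.size());
        return 0;
    }

Running the sketch, the second prompt (5 tokens) exceeds the pretend batch size of 4 and triggers the error path, mirroring what the patched example now does instead of resizing the input.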