mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	llama : support multiple classifier outputs and labels (#13940)
This commit is contained in:
		| @@ -288,9 +288,10 @@ namespace GGUFMeta { | ||||
|  | ||||
|     template<typename T> | ||||
|     bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) { | ||||
|         const int kid = gguf_find_key(meta.get(), key.c_str()); | ||||
|         const gguf_context * ctx = meta.get(); | ||||
|         const int kid = gguf_find_key(ctx, key.c_str()); | ||||
|  | ||||
|         if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { | ||||
|         if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { | ||||
|             if (required) { | ||||
|                 throw std::runtime_error(format("array key not found in model: %s", key.c_str())); | ||||
|             } | ||||
| @@ -298,28 +299,40 @@ namespace GGUFMeta { | ||||
|         } | ||||
|  | ||||
|         struct GGUFMeta::ArrayInfo arr_info = | ||||
|             GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid); | ||||
|             GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid); | ||||
|  | ||||
|         switch (arr_info.gt) { | ||||
|             case GGUF_TYPE_UINT32: | ||||
|             case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) || | ||||
|                                                 (std::is_same<T, uint32_t>::value)); break; | ||||
|             case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,    float>::value)); break; | ||||
|             case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,     int32_t>::value) || | ||||
|                                                 (std::is_same<T,    uint32_t>::value)); break; | ||||
|             case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,       float>::value)); break; | ||||
|             case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break; | ||||
|             default: | ||||
|                 throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); | ||||
|                 throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str())); | ||||
|         } | ||||
|  | ||||
|         result.resize(arr_info.length); | ||||
|         result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); | ||||
|         if constexpr (std::is_same<T, std::string>::value) { | ||||
|             const size_t n_items = gguf_get_arr_n(ctx, kid); | ||||
|             result.clear(); | ||||
|  | ||||
|             for (size_t i = 0; i < n_items; i++) { | ||||
|                 const T value = gguf_get_arr_str(ctx, kid, i); | ||||
|                 result.emplace_back(value); | ||||
|             } | ||||
|         } else { | ||||
|             result.resize(arr_info.length); | ||||
|             result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); | ||||
|         } | ||||
|  | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     template<typename T, size_t N_MAX> | ||||
|     bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) { | ||||
|         const int kid = gguf_find_key(meta.get(), key.c_str()); | ||||
|         const gguf_context * ctx = meta.get(); | ||||
|         const int kid = gguf_find_key(ctx, key.c_str()); | ||||
|  | ||||
|         if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { | ||||
|         if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { | ||||
|             if (required) { | ||||
|                 throw std::runtime_error(format("array key not found in model: %s", key.c_str())); | ||||
|             } | ||||
| @@ -327,22 +340,32 @@ namespace GGUFMeta { | ||||
|         } | ||||
|  | ||||
|         struct GGUFMeta::ArrayInfo arr_info = | ||||
|             GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid); | ||||
|             GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid); | ||||
|  | ||||
|         switch (arr_info.gt) { | ||||
|             case GGUF_TYPE_UINT32: | ||||
|             case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,  int32_t>::value) || | ||||
|                                                 (std::is_same<T, uint32_t>::value)); break; | ||||
|             case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,    float>::value)); break; | ||||
|             case GGUF_TYPE_INT32:   GGML_ASSERT((std::is_same<T,     int32_t>::value) || | ||||
|                                                 (std::is_same<T,    uint32_t>::value)); break; | ||||
|             case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T,       float>::value)); break; | ||||
|             case GGUF_TYPE_STRING:  GGML_ASSERT((std::is_same<T, std::string>::value)); break; | ||||
|             default: | ||||
|                 throw std::runtime_error(format("%s is not a float32/uint32/int32 array", key.c_str())); | ||||
|                 throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str())); | ||||
|         } | ||||
|  | ||||
|         if (arr_info.length > N_MAX) { | ||||
|             throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX)); | ||||
|         } | ||||
|  | ||||
|         std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); | ||||
|         if constexpr (std::is_same<T, std::string>::value) { | ||||
|             const size_t n_items = gguf_get_arr_n(ctx, kid); | ||||
|  | ||||
|             for (size_t i = 0; i < n_items; i++) { | ||||
|                 const T value = gguf_get_arr_str(ctx, kid, i); | ||||
|                 result[i] = value; | ||||
|             } | ||||
|         } else { | ||||
|             std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); | ||||
|         } | ||||
|  | ||||
|         return true; | ||||
|     } | ||||
| @@ -352,6 +375,8 @@ namespace GGUFMeta { | ||||
|         return get_arr(llm_kv(kid), result, required); | ||||
|     } | ||||
|  | ||||
|     template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required); | ||||
|  | ||||
|     template<typename T> | ||||
|     bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { | ||||
|         auto it = kv_overrides.find(key); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Sigbjørn Skjæret
					Sigbjørn Skjæret