Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-10-31 08:51:55 +00:00
	ggml: refactor cross entropy loss CPU impl. (ggml/976)
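This change reworks the CPU kernel for the cross entropy loss op: each thread now accumulates its partial result in a local variable and publishes it with a single store, which removes the zero-initialization pass and one thread barrier; loop counters and row bounds widen from int to int64_t to match ggml's int64_t tensor dimensions; the row pointers become const-correct; and the blanket contiguity assertions are replaced by the weaker requirement that rows are contiguous (nb[0] equals the type size). Two scheduler declarations in the backend header additionally gain a "returns success" comment.

For reference, the quantity the kernel evaluates is the mean cross entropy over nr rows of nc classes, computed via a max-stabilized log-softmax. This is a reconstruction from the kernel's structure, not text taken verbatim from the diff (the -1/nr scaling is inferred from the backward pass's d_by_nr below):

\[
\mathcal{L} = -\frac{1}{n_r} \sum_{i=1}^{n_r} \sum_{j=1}^{n_c} y_{ij} \left( x_{ij} - m_i - \log \sum_{k=1}^{n_c} e^{x_{ik} - m_i} \right), \qquad m_i = \max_k x_{ik},
\]

where x = src0 (logits) and y = src1 (target probabilities); subtracting m_i is the usual log-sum-exp stabilization, performed here by ggml_vec_max_f32 followed by ggml_vec_log_soft_max_f32.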
Johannes Gäßler, committed by Georgi Gerganov

parent 5d5ab1e5cc
commit eee39bdc96
@@ -247,7 +247,7 @@ extern "C" {
     GGML_API void                 ggml_backend_sched_free(ggml_backend_sched_t sched);
 
     // Initialize backend buffers from a measure graph
-    GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
+    GGML_API bool                 ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success
 
     GGML_API int                  ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched);
     GGML_API ggml_backend_t       ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i);
@@ -262,7 +262,7 @@ extern "C" {
     GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
     // Allocate and compute graph on the backend scheduler
-    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+    GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
     GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
     GGML_API enum ggml_status     ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
     GGML_API void                 ggml_backend_sched_synchronize(ggml_backend_sched_t sched);
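The two header tweaks above only document existing behavior: both calls signal failure through their bool return value. A minimal caller sketch, assuming a scheduler and the two graphs were created elsewhere (the helper name is hypothetical, not part of this commit):

    #include <stdbool.h>
    #include <stdio.h>
    #include "ggml-backend.h"

    // Hypothetical helper: reserve buffers from a measure graph, then
    // allocate the real graph, checking the documented bool results
    // instead of ignoring them.
    static bool sched_prepare(ggml_backend_sched_t sched,
                              struct ggml_cgraph * measure_graph,
                              struct ggml_cgraph * graph) {
        if (!ggml_backend_sched_reserve(sched, measure_graph)) {
            fprintf(stderr, "ggml_backend_sched_reserve failed\n");
            return false;
        }
        if (!ggml_backend_sched_alloc_graph(sched, graph)) {
            fprintf(stderr, "ggml_backend_sched_alloc_graph failed\n");
            return false;
        }
        return true;
    }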
@@ -4232,9 +4232,13 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa
 }
 
 struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
+    if (ggml_is_empty(tensor)) {
+        return tensor;
+    }
     if (tensor->buffer) {
         ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
     } else {
+        GGML_ASSERT(tensor->data);
         memset(tensor->data, 0, ggml_nbytes(tensor));
     }
     return tensor;
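ggml_set_zero is hardened in two ways: tensors with zero elements now return early instead of issuing a zero-byte memset, and a host tensor without backing memory trips an explicit assertion instead of calling memset on a NULL pointer. A hedged usage sketch (ctx is an assumed, already-initialized ggml context):

    // Hypothetical example: a 1-D tensor with 0 elements is "empty".
    struct ggml_tensor * t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 0);
    ggml_set_zero(t); // no-op now: ggml_is_empty(t) is true, no memset runs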
@@ -16851,41 +16855,40 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
     GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(ggml_is_scalar(dst));
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    // TODO: handle transposed/permuted matrices
+    const int64_t nc = src0->ne[0];
+    const int64_t nr = ggml_nrows(src0);
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    float * sums = (float *) params->wdata;
-
-    // TODO: handle transposed/permuted matrices
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
+    float * sums =  (float *) params->wdata;
+    float * st   = ((float *) params->wdata) + nth + ith*nc;
+    float sum_thread = 0.0f;
 
     GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
 
-    if (ith == 0) {
-        memset(sums, 0, sizeof(float) * (nth + nth * nc));
-    }
-    ggml_barrier(params->threadpool);
-
     // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    const int64_t dr = (nr + nth - 1)/nth;
 
     // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    for (int i1 = ir0; i1 < ir1; i1++) {
-        float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
-        float * st = ((float *) params->wdata) + nth + ith*nc;
+    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
+        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
+        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
             assert(!isnan(s0[i]));
             assert(!isnan(s1[i]));
@@ -16894,23 +16897,24 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
 
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_log_soft_max_f32(nc, st, s0, max);
-        assert(sum >= 0.0);
+        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
+        assert(sum_softmax >= 0.0);
 
-        ggml_vec_add1_f32(nc, st, st, -sum);
+        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
         ggml_vec_mul_f32(nc, st, st, s1);
 
-        float st_sum = 0.0f;
-        ggml_vec_sum_f32(nc, &st_sum, st);
-        sums[ith] += st_sum;
+        float sum_st = 0.0f;
+        ggml_vec_sum_f32(nc, &sum_st, st);
+        sum_thread += sum_st;
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             assert(!isnan(st[i]));
             assert(!isinf(st[i]));
         }
 #endif
     }
+    sums[ith] = sum_thread;
     ggml_barrier(params->threadpool);
 
     if (ith == 0) {
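The reduction is the core of the refactor. Before, thread 0 zeroed the whole sums workspace behind a barrier and every thread repeatedly added into sums[ith]; now each thread keeps a private sum_thread and performs exactly one store before the single remaining barrier. A minimal sketch of the resulting pattern; the thread-0 reduction and the final -1/nr scaling are assumed from surrounding code, not shown in this hunk:

    // Sketch, not verbatim kernel code.
    float sum_thread = 0.0f;
    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
        // ... per-row cross entropy contribution accumulates here ...
    }
    sums[ith] = sum_thread;           // one store per thread: no upfront memset
    ggml_barrier(params->threadpool); // all partial sums visible past this point
    if (ith == 0) {
        float total = 0.0f;
        for (int t = 0; t < nth; ++t) {
            total += sums[t];
        }
        ((float *) dst->data)[0] = -total / (float) nr; // assumed final scaling
    }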
@@ -16976,7 +16980,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
             assert(!isnan(s0[i]));
             assert(!isnan(s1[i]));
@@ -16995,7 +16999,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
 #ifndef NDEBUG
-        for (int i = 0; i < nc; ++i) {
+        for (int64_t i = 0; i < nc; ++i) {
             assert(!isnan(ds0[i]));
             assert(!isinf(ds0[i]));
         }
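The remaining int to int64_t changes in the backward pass (and throughout) are type hygiene: ggml stores tensor dimensions in ne[] as int64_t, so a 32-bit loop counter would truncate once a dimension exceeds INT32_MAX. An illustrative sketch with a hypothetical size:

    // ne[0] is int64_t in ggml (from <stdint.h>); an int copy would wrap.
    const int64_t nc = 3000000000LL; // hypothetical: > INT32_MAX
    for (int64_t i = 0; i < nc; ++i) {
        // index and bound share the same 64-bit type, so no truncation
    }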