	ggml : implement soft_max_ext (CPU)
 ggml.c | 52
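In short, the kernel changed below fuses scaling and masking into the soft max itself: each row of src0 is scaled, an optional mask row (broadcast via i1 % ne11) is added into a per-thread scratch buffer, and the usual numerically stable soft max is taken over that buffer, i.e. roughly dst = softmax(src0*scale + mask) row by row. The helper below restates that per-row computation in plain C as a minimal sketch; the name soft_max_ext_row and its argument list are illustrative only and are not part of ggml.

#include <math.h>

// Illustrative only: one row of the extended soft max implemented in the diff
// below -- scale the inputs, add the (optional) mask row, then take a standard
// numerically stable soft max. Name and signature are hypothetical.
static void soft_max_ext_row(
        const float * sp,    // input row, nc values
        const float * mp,    // optional additive mask row, may be NULL
        float         scale,
        float       * dp,    // output row, nc values
        int           nc) {
    float max = -INFINITY;
    for (int i = 0; i < nc; i++) {
        const float w = sp[i]*scale + (mp ? mp[i] : 0.0f);
        if (w > max) {
            max = w;
        }
    }

    float sum = 0.0f;
    for (int i = 0; i < nc; i++) {
        const float w = sp[i]*scale + (mp ? mp[i] : 0.0f);
        if (w == -INFINITY) {
            dp[i] = 0.0f;    // masked-out position contributes nothing
        } else {
            dp[i] = expf(w - max);
            sum  += dp[i];
        }
    }

    for (int i = 0; i < nc; i++) {
        dp[i] /= sum;        // assumes at least one unmasked entry, so sum > 0
    }
}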
@@ -4829,7 +4829,9 @@ static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_tensor  * mask,
         float                 scale,
         bool                  inplace) {
+    GGML_ASSERT(ggml_is_contiguous(a));
     if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
         GGML_ASSERT(mask->ne[3] == 1);
         GGML_ASSERT(ggml_can_repeat_rows(mask, a));
@@ -10571,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero(
 static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);
 
@@ -10595,29 +10602,39 @@ static void ggml_compute_forward_soft_max_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    float * wdata = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
-        float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float *dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
 
+        // broadcast the mask across rows
+        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+
+        float * wp = wdata;
+        for (int i = 0; i < nc; i++) {
+            wp[i] = sp[i]*scale + (mp ? mp[i] : 0.0f);
+        }
+
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(sp[i]));
+            assert(!isnan(wp[i]));
         }
 #endif
 
         float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, sp);
+        ggml_vec_max_f32(nc, &max, wp);
 
         ggml_float sum = 0.0;
 
         uint16_t scvt;
         for (int i = 0; i < nc; i++) {
-            if (sp[i] == -INFINITY) {
+            if (wp[i] == -INFINITY) {
                 dp[i] = 0.0f;
             } else {
-                // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
+                // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
                 sum += (ggml_float)val;
@@ -10642,11 +10659,12 @@ static void ggml_compute_forward_soft_max_f32(
 static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_f32(params, src0, dst);
+                ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -13883,7 +13901,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
+                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
             } break;
        case GGML_OP_SOFT_MAX_BACK:
            {
@@ -15919,6 +15937,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                        cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                    }
                } break;
+            case GGML_OP_SOFT_MAX:
+                {
+                    n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
+
+                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+                } break;
            case GGML_OP_CONV_TRANSPOSE_1D:
                {
                    GGML_ASSERT(node->src[0]->ne[3] == 1);
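For context, here is a caller-side sketch of how the extended op might be built into a graph. It assumes the public wrapper added together with this change is ggml_soft_max_ext(ctx, a, mask, scale); the wrapper itself is not visible in these hunks (only ggml_soft_max_impl(a, mask, scale, inplace) is), and the tensor names and sizes are made up.

#include <math.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // hypothetical attention-style scores [n_kv, n_tokens] and an additive mask
    struct ggml_tensor * kq      = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 32);
    struct ggml_tensor * kq_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 32);
    ggml_set_f32(kq,      0.0f);
    ggml_set_f32(kq_mask, 0.0f);

    // fused softmax(kq*scale + kq_mask), instead of a separate
    // ggml_scale + ggml_add + ggml_soft_max chain; 64 = assumed head size
    struct ggml_tensor * probs = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(64.0f));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, probs);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    ggml_free(ctx);
    return 0;
}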