	ggml-opt: fix data corruption (ggml/1022)
committed by Georgi Gerganov

parent 9abe9eeae9
commit 02e4eaf22f
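The thread connecting the hunks below: after the new optimization interface, cgraph->grads and cgraph->grad_accs are parallel arrays indexed by a tensor's slot in cgraph->visited_hash_set, and the corruption came from code that sized or indexed those arrays by other means — a recomputed hash size in ggml_build_backward_expand, node-position offsets in ggml_graph_view, and index-for-index copies in ggml_graph_cpy. As a reader's sketch (not code from this commit), the lookup convention that the rewritten comments document looks roughly like this, assuming the internal helpers from ggml-impl.h; the public ggml_graph_get_grad accessor follows the same pattern:

#include "ggml.h"
#include "ggml-impl.h" // internal: ggml_hash_find, ggml_bitset_get, GGML_HASHSET_FULL

// a minimal sketch of the lookup convention: a tensor's gradient lives at the
// same index as the tensor itself in cgraph->visited_hash_set.keys
static struct ggml_tensor * sketch_get_grad(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
    const size_t i = ggml_hash_find(&cgraph->visited_hash_set, node);
    if (i == GGML_HASHSET_FULL || !ggml_bitset_get(cgraph->visited_hash_set.used, i)) {
        return NULL; // node was never inserted into this graph
    }
    return cgraph->grads ? cgraph->grads[i] : NULL;
}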
@@ -5019,8 +5019,10 @@ static void ggml_hash_map_free(struct hash_map * map) {
 }
 
 // utility functions to change gradients
-// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
-// else if a is in zero_table, replace a
+// isrc is the index of tensor in cgraph->visited_hash_set.keys
+// the corresponding gradient (accumulators) are also at position isrc
+// if tensor has a gradient accumulator, modify that accumulator in-place
+// else if there is no gradient for tensor, set the corresponding value
 // else, just add/subtract/etc. the gradients
 
 static void ggml_add_or_set(
@@ -5028,11 +5030,14 @@ static void ggml_add_or_set(
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
         struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
-        cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
+        cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]);
     } else {
         cgraph->grads[isrc] = tensor;
     }
     ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
@@ -5040,18 +5045,20 @@ static void ggml_acc_or_set(
         struct ggml_context * ctx,
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
-        struct ggml_tensor  * src,
         struct ggml_tensor  * tensor,
         const  size_t         nb1,
         const  size_t         nb2,
         const  size_t         nb3,
         const  size_t         offset) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]);
     } else {
         struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
         cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false);
     }
     ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
@@ -5059,13 +5066,15 @@ static void ggml_add1_or_set(
         struct ggml_context * ctx,
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
-        struct ggml_tensor  * src,
         struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
     } else {
         cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src);
     }
     ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
@@ -5074,11 +5083,14 @@ static void ggml_sub_or_set(
         struct ggml_cgraph  * cgraph,
         size_t                isrc,
         struct ggml_tensor  * tensor) {
+    struct ggml_tensor * src = cgraph->visited_hash_set.keys[isrc];
+    GGML_ASSERT(src);
     if (cgraph->grads[isrc]) {
         cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]);
     } else {
         cgraph->grads[isrc] = ggml_neg(ctx, tensor);
     }
     ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name);
     ggml_build_forward_expand(cgraph, cgraph->grads[isrc]);
 }
 
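All four *_or_set helpers now derive src from isrc instead of trusting a separately passed tensor, so the hash-set index is the single source of truth and the new GGML_ASSERT(src) fails loudly on a bad slot; the call sites in ggml_compute_backward shrink accordingly (see the GGML_OP_SUM, GGML_OP_MEAN and view-backward hunks further down). A toy illustration in plain C, not ggml code, of why one index is safer than a (tensor, index) pair:

#include <assert.h>
#include <stddef.h>

struct toy_tensor { const char * name; };

struct toy_graph {
    size_t              size;
    struct toy_tensor **keys;  // toy stand-in for visited_hash_set.keys
    struct toy_tensor **grads; // grads[i] belongs to keys[i]
};

// with parallel arrays, a single index cannot disagree with itself,
// whereas a separately passed (tensor, index) pair can
static struct toy_tensor * toy_src(const struct toy_graph * g, size_t isrc) {
    assert(isrc < g->size && g->keys[isrc] != NULL); // mirrors GGML_ASSERT(src)
    return g->keys[isrc];
}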
@@ -5095,12 +5107,12 @@ static void ggml_compute_backward(
     struct ggml_tensor * src1 = tensor->src[1];
     struct ggml_tensor * src2 = tensor->src[2];
     struct ggml_hash_set * hash_set = &cgraph->visited_hash_set;
-    const size_t isrc0 = ggml_hash_find(hash_set, src0);
-    const size_t isrc1 = ggml_hash_find(hash_set, src1);
-    const size_t isrc2 = ggml_hash_find(hash_set, src2);
-    const bool src0_needs_grads = isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
-    const bool src1_needs_grads = isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
-    const bool src2_needs_grads = isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
+    const size_t isrc0 = src0 ? ggml_hash_find(hash_set, src0) : (size_t) -1;
+    const size_t isrc1 = src1 ? ggml_hash_find(hash_set, src1) : (size_t) -1;
+    const size_t isrc2 = src2 ? ggml_hash_find(hash_set, src2) : (size_t) -1;
+    const bool src0_needs_grads = src0 && isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0];
+    const bool src1_needs_grads = src1 && isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1];
+    const bool src2_needs_grads = src2 && isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2];
 
     switch (tensor->op) {
         case GGML_OP_DUP: {
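Most ops have fewer than three sources, so the old code routinely hashed NULL pointers. A NULL key is never inserted into the table, so the old lookup only worked by falling through to an unused slot; the new guard makes the no-source case explicit instead of relying on probe behavior. A toy sketch of the fixed shape, with a hash of the same flavor as ggml's pointer hash:

#include <stdint.h>
#include <stddef.h>

// toy stand-ins, not ggml's implementation: hashing is pointer-based,
// so a NULL source has no meaningful slot at all
static size_t toy_hash_find(const void * key, size_t size) {
    return ((uintptr_t) key >> 4) % size;
}

static size_t toy_find_src(const void * src, size_t size) {
    return src ? toy_hash_find(src, size) : (size_t) -1; // -1 means "no source"
}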
@@ -5200,7 +5212,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_SUM: {
             if (src0_needs_grads) {
-                ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad);
+                ggml_add1_or_set(ctx, cgraph, isrc0, grad);
             }
         } break;
         case GGML_OP_SUM_ROWS: {
@@ -5210,7 +5222,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MEAN: {
             if (src0_needs_grads) {
-                ggml_add1_or_set(ctx, cgraph, isrc0, src0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
+                ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
             }
         } break;
         case GGML_OP_REPEAT: {
@@ -5363,7 +5375,7 @@ static void ggml_compute_backward(
                     nb3 = (nb3 / n0) * ng;
                 }
 
-                ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset);
+                ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset);
             }
         } break;
         case GGML_OP_PERMUTE: {
@@ -5597,10 +5609,9 @@ void ggml_build_backward_expand(
 
     const int n_nodes_f = cgraph->n_nodes;
 
-    const size_t hash_size = ggml_hash_size(2*cgraph->size);
-    memset(cgraph->grads,     0, hash_size*sizeof(struct ggml_tensor *));
-    memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
-    bool * grads_needed = calloc(hash_size, sizeof(bool));
+    memset(cgraph->grads,     0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool));
 
     {
         bool any_params = false;
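This hunk is the one the commit title points at. The gradient bookkeeping arrays were cleared with a freshly recomputed ggml_hash_size(2*cgraph->size), while the arrays themselves are sized by the graph's actual visited_hash_set.size; whenever the two values disagree, the memset and calloc cover the wrong number of entries, and writes past the real allocation corrupt neighboring memory. Sizing every operation from the stored visited_hash_set.size keeps the writes in bounds by construction. A toy sketch of the bug/fix pattern in plain C (not ggml's allocator):

#include <string.h>

struct toy_graph {
    size_t  hash_size; // the size the arrays below were allocated with
    void  **grads;     // hash_size entries
    void  **grad_accs; // hash_size entries
};

// BUG pattern: if recomputed > g->hash_size, this writes past both arrays
static void toy_reset_bug(struct toy_graph *g, size_t recomputed) {
    memset(g->grads,     0, recomputed * sizeof(void *));
    memset(g->grad_accs, 0, recomputed * sizeof(void *));
}

// FIX pattern: always size the writes from the stored allocation size
static void toy_reset_fix(struct toy_graph *g) {
    memset(g->grads,     0, g->hash_size * sizeof(void *));
    memset(g->grad_accs, 0, g->hash_size * sizeof(void *));
}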
@@ -5621,7 +5632,7 @@ void ggml_build_backward_expand(
             continue;
         }
 
-        bool node_needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM;
+        bool node_needs_grad = (node->flags & GGML_TENSOR_FLAG_PARAM) || (node->flags & GGML_TENSOR_FLAG_LOSS);
         bool ignore_src[GGML_MAX_SRC] = {false};
         switch (node->op) {
             // gradients in node->src[0] for one reason or another have no effect on output gradients
@@ -5638,7 +5649,7 @@ void ggml_build_backward_expand(
             } break;
 
             // gradients in node->src[1] for one reason or another have no effect on output gradients
-            case GGML_OP_CPY:           // gradients in CPY target  are irrelevant
+            case GGML_OP_CPY:           // gradients in CPY target are irrelevant
             case GGML_OP_GET_ROWS:      // row indices not differentiable
             case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
             case GGML_OP_ROPE:          // positions not differentiable
@@ -5665,9 +5676,12 @@ void ggml_build_backward_expand(
             node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
 
         const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
+        GGML_ASSERT(igrad != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
         if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
-            cgraph->grads[igrad]     = ggml_dup_tensor(ctx_static, node);
-            cgraph->grad_accs[igrad] = cgraph->grads[igrad];
+            cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
+            cgraph->grads[igrad]     = cgraph->grad_accs[igrad];
+            ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
         }
         grads_needed[igrad] = true;
     }
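Two things change here: the loop now asserts that every node needing a gradient is actually present in visited_hash_set (failing loudly instead of indexing with GGML_HASHSET_FULL), and the statically allocated accumulator is created as grad_accs[igrad] first, aliased into grads[igrad], and given a distinct "grad acc for %s" name. A hedged usage sketch, assuming the ggml_graph_get_grad_acc accessor declared in ggml.h:

#include <stdio.h>
#include "ggml.h"

// after ggml_build_backward_expand with accumulate=true, a param's
// accumulator should be a named tensor ("grad acc for <param>") that
// also serves as its initial gradient
static void sketch_print_acc(struct ggml_cgraph * cgraph, struct ggml_tensor * param) {
    struct ggml_tensor * acc = ggml_graph_get_grad_acc(cgraph, param);
    if (acc) {
        printf("%s -> %s\n", param->name, acc->name);
    }
}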
@@ -5761,15 +5775,15 @@ struct ggml_cgraph * ggml_new_graph(struct ggml_context * ctx) {
 
 struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) {
     struct ggml_cgraph cgraph = {
-        /*.size         =*/ 0,
-        /*.n_nodes      =*/ i1 - i0,
-        /*.n_leafs      =*/ 0,
-        /*.nodes        =*/ cgraph0->nodes + i0,
-        /*.grads        =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL,
-        /*.grad_accs    =*/ cgraph0->grad_accs ? cgraph0->grad_accs + i0 : NULL,
-        /*.leafs        =*/ NULL,
-        /*.hash_table   =*/ { 0, NULL, NULL },
-        /*.order        =*/ cgraph0->order,
+        /*.size             =*/ 0,
+        /*.n_nodes          =*/ i1 - i0,
+        /*.n_leafs          =*/ 0,
+        /*.nodes            =*/ cgraph0->nodes + i0,
+        /*.grads            =*/ NULL, // gradients would need visited_hash_set
+        /*.grad_accs        =*/ NULL,
+        /*.leafs            =*/ NULL,
+        /*.visited_hash_set =*/ { 0, NULL, NULL },
+        /*.order            =*/ cgraph0->order,
     };
 
     return cgraph;
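The old view aliased cgraph0->grads + i0, an offset by node position left over from when gradients paralleled the nodes array; gradients are now indexed by hash-set slot and a view has no visited_hash_set, so any offset into the parent's arrays would be meaningless and the view's gradient fields are simply NULL (the inline comment says why). A hedged usage sketch, assuming the ggml_graph_get_grad accessor from ggml.h:

#include "ggml.h"

// a graph view carries nodes only; gradient queries must target the
// owning graph
static struct ggml_tensor * sketch_grad_for_split(
        struct ggml_cgraph * cgraph, int i0, int i1, struct ggml_tensor * node) {
    struct ggml_cgraph view = ggml_graph_view(cgraph, i0, i1);
    (void) view; // view.grads == NULL by construction after this commit
    return ggml_graph_get_grad(cgraph, node); // NULL if node has no gradient
}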
@@ -5799,12 +5813,22 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
         }
     }
 
+    if (dst->grads) {
+        memset(dst->grads,     0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
+        memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct ggml_tensor *));
+    }
     if (src->grads) {
         GGML_ASSERT(dst->grads     != NULL);
         GGML_ASSERT(dst->grad_accs != NULL);
         for (int i = 0; i < src->n_nodes; ++i) {
             const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]);
             const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]);
+
+            GGML_ASSERT(igrad_src != GGML_HASHSET_FULL);
+            GGML_ASSERT(ggml_bitset_get(src->visited_hash_set.used, igrad_src));
+            GGML_ASSERT(igrad_dst != GGML_HASHSET_FULL);
+            GGML_ASSERT(ggml_bitset_get(dst->visited_hash_set.used, igrad_dst));
+
             dst->grads[igrad_dst]     = src->grads[igrad_src];
             dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src];
         }
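ggml_graph_cpy cannot copy the gradient arrays index-for-index because a slot is a function of the hash table: src and dst generally have tables of different sizes, so the same tensor lands in different slots. The copy therefore re-finds every node in both tables (with the new asserts catching absent nodes) and zeroes dst's arrays first so stale entries cannot survive. A toy illustration of the slot mismatch (ggml's real hash shifts the pointer before the modulo):

#include <stdint.h>
#include <stdio.h>

// the same key occupies different slots in tables of different sizes
int main(void) {
    const uintptr_t key = 0xbeef0;                 // stand-in for a tensor address
    const size_t size_src = 2053, size_dst = 4099; // two hash table sizes
    printf("slot in src: %zu\n", (size_t) (key % size_src));
    printf("slot in dst: %zu\n", (size_t) (key % size_dst));
    return 0;
}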
@@ -5839,12 +5863,8 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
 
         if (node->op == GGML_OP_OPT_STEP_ADAMW) {
             // clear momenta
-            if (node->src[2]->data) {
-                ggml_set_zero(node->src[2]);
-            }
-            if (node->src[3]->data) {
-                ggml_set_zero(node->src[3]);
-            }
+            ggml_set_zero(node->src[2]);
+            ggml_set_zero(node->src[3]);
         }
 
         // initial gradients of loss should be 1, 0 otherwise