mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	* CUDA eval works * stochastic gradient descent op * Adam except decay * CUDA CROSS_ENTROPY_LOSS_BACK * CUDA mnist-fc training works * backend CLI arg * refactor gguf load * remove sched from opt_step_adam * implement l1 regularization (weight decay) * extra call to add optimizer * initialize gradients with ggml_graph_reset * gradient accumulation * increment iter per eval instead of epoch * adjust backend interfaces * fix ggml_graph_reset without backend * fix ggml graph export/import * fixup * rename * revert ggml_opt changes * more general CUDA repeat_back * update documentation, fix CNN * validation split * add clarifying comment * optimize PyTorch training * adjust buffer size, thread count * fix 0.0f validation split * Update examples/mnist/mnist-common.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * fix gradient accumulation * tensor flag for accumulators -> tensor hash set * Update include/ggml.h Co-authored-by: slaren <slarengh@gmail.com> * Update tests/test-backend-ops.cpp Co-authored-by: slaren <slarengh@gmail.com> * Update tests/test-backend-ops.cpp Co-authored-by: slaren <slarengh@gmail.com> * fix test prints * Update src/ggml-backend.c Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * better CUDA support for noncontiguous out_prod * add comment --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: slaren <slarengh@gmail.com>
		
			
				
	
	
		
			155 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
			
		
		
	
	
			155 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			C
		
	
	
	
	
	
#pragma once
 | 
						|
 | 
						|
// ggml-backend internal header
 | 
						|
 | 
						|
#include "ggml-backend.h"
 | 
						|
 | 
						|
#ifdef  __cplusplus
 | 
						|
extern "C" {
 | 
						|
#endif
 | 
						|
 | 
						|
    //
    // Backend buffer
    //
 | 
						|
 | 
						|
    // buffer type
 | 
						|
    // untyped pointer used as the `context` member of struct ggml_backend_buffer_type
    typedef void * ggml_backend_buffer_type_context_t;
 | 
						|
 | 
						|
    struct ggml_backend_buffer_type_i {
 | 
						|
        const char *          (*GGML_CALL get_name)        (ggml_backend_buffer_type_t buft);
 | 
						|
        // allocate a buffer of this type
 | 
						|
        ggml_backend_buffer_t (*GGML_CALL alloc_buffer)    (ggml_backend_buffer_type_t buft, size_t size);
 | 
						|
        // tensor alignment
 | 
						|
        size_t                (*GGML_CALL get_alignment)   (ggml_backend_buffer_type_t buft);
 | 
						|
        // max buffer size that can be allocated
 | 
						|
        size_t                (*GGML_CALL get_max_size)    (ggml_backend_buffer_type_t buft);
 | 
						|
        // data size needed to allocate the tensor, including padding
 | 
						|
        size_t                (*GGML_CALL get_alloc_size)  (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor);
 | 
						|
        // check if tensor data is in host memory
 | 
						|
        bool                  (*GGML_CALL is_host)         (ggml_backend_buffer_type_t buft);
 | 
						|
    };
 | 
						|
 | 
						|
    struct ggml_backend_buffer_type {
 | 
						|
        struct ggml_backend_buffer_type_i  iface;
 | 
						|
        ggml_backend_buffer_type_context_t context;
 | 
						|
    };
 | 
						|
 | 
						|
    // buffer
 | 
						|
    // untyped pointer used as the `context` member of struct ggml_backend_buffer
    typedef void * ggml_backend_buffer_context_t;
 | 
						|
 | 
						|
    struct ggml_backend_buffer_i {
 | 
						|
        const char * (*GGML_CALL get_name)      (ggml_backend_buffer_t buffer);
 | 
						|
        void         (*GGML_CALL free_buffer)   (ggml_backend_buffer_t buffer);
 | 
						|
        void *       (*GGML_CALL get_base)      (ggml_backend_buffer_t buffer);
 | 
						|
        void         (*GGML_CALL init_tensor)   (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
 | 
						|
        void         (*GGML_CALL memset_tensor) (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor,     uint8_t value, size_t offset, size_t size);
 | 
						|
        void         (*GGML_CALL set_tensor)    (ggml_backend_buffer_t buffer,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
 | 
						|
        void         (*GGML_CALL get_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
 | 
						|
        bool         (*GGML_CALL cpy_tensor)    (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer
 | 
						|
        void         (*GGML_CALL clear)         (ggml_backend_buffer_t buffer, uint8_t value);
 | 
						|
        void         (*GGML_CALL reset)         (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
 | 
						|
    };
 | 
						|
 | 
						|
    struct ggml_backend_buffer {
 | 
						|
        struct ggml_backend_buffer_i  iface;
 | 
						|
        ggml_backend_buffer_type_t    buft;
 | 
						|
        ggml_backend_buffer_context_t context;
 | 
						|
        size_t size;
 | 
						|
        enum ggml_backend_buffer_usage usage;
 | 
						|
    };
 | 
						|
 | 
						|
    // Returns a buffer handle built from a buffer type, a callback table (passed by value),
    // an untyped context pointer and a size.
    // NOTE(review): ownership of `context` and the unit of `size` are not visible in this header - confirm.
    GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init(
                   ggml_backend_buffer_type_t      buft,
            struct ggml_backend_buffer_i           iface,
                   ggml_backend_buffer_context_t   context,
                   size_t                          size);
 | 
						|
 | 
						|
    // do not use directly, use ggml_backend_tensor_copy instead
    // NOTE(review): the bool result presumably reports whether the copy was performed - confirm
    bool ggml_backend_buffer_copy_tensor(const struct ggml_tensor * src, struct ggml_tensor * dst);
 | 
						|
 | 
						|
    // buffer that contains a collection of buffers
 | 
						|
    GGML_CALL ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer_t * buffers, size_t n_buffers);
 | 
						|
    GGML_CALL bool                  ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer);
 | 
						|
    GGML_CALL void                  ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
 | 
						|
 | 
						|
    //
    // Backend
    //
 | 
						|
 | 
						|
    // untyped pointer used as the `context` member of struct ggml_backend
    typedef void * ggml_backend_context_t;
 | 
						|
 | 
						|
    struct ggml_backend_i {
 | 
						|
        const char * (*GGML_CALL get_name)(ggml_backend_t backend);
 | 
						|
 | 
						|
        void (*GGML_CALL free)(ggml_backend_t backend);
 | 
						|
 | 
						|
        // buffer allocation
 | 
						|
        ggml_backend_buffer_type_t (*GGML_CALL get_default_buffer_type)(ggml_backend_t backend);
 | 
						|
 | 
						|
        // (optional) asynchronous tensor data access
 | 
						|
        void (*GGML_CALL set_tensor_async)(ggml_backend_t backend,       struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
 | 
						|
        void (*GGML_CALL get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor,       void * data, size_t offset, size_t size);
 | 
						|
        bool (*GGML_CALL cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
 | 
						|
 | 
						|
        // (optional) complete all pending operations
 | 
						|
        void (*GGML_CALL synchronize)(ggml_backend_t backend);
 | 
						|
 | 
						|
        // compute graph with a plan (not used currently)
 | 
						|
        // create a new plan for a graph
 | 
						|
        ggml_backend_graph_plan_t (*GGML_CALL graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
 | 
						|
        void                      (*GGML_CALL graph_plan_free)   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
 | 
						|
        // update the plan with a new graph - this should be faster than creating a new plan when the graph has the same topology
 | 
						|
        void                      (*GGML_CALL graph_plan_update) (ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph);
 | 
						|
        // compute the graph with the plan
 | 
						|
        enum ggml_status          (*GGML_CALL graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
 | 
						|
 | 
						|
        // compute graph without a plan (async)
 | 
						|
        enum ggml_status (*GGML_CALL graph_compute)     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
 | 
						|
 | 
						|
        // check if the backend can compute an operation
 | 
						|
        bool (*GGML_CALL supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
 | 
						|
 | 
						|
        // check if the backend can use tensors allocated in a buffer type
 | 
						|
        bool (*GGML_CALL supports_buft)(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
 | 
						|
 | 
						|
        // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer
 | 
						|
        // these should be expensive operations with large batch sizes that may benefit from running on this backend
 | 
						|
        // even if the weight has to be copied from the CPU temporarily
 | 
						|
        bool (*GGML_CALL offload_op)(ggml_backend_t backend, const struct ggml_tensor * op);
 | 
						|
 | 
						|
        // (optional) event synchronization
 | 
						|
        // create a new event that can record events on this backend instance
 | 
						|
        ggml_backend_event_t (*GGML_CALL event_new)         (ggml_backend_t backend);
 | 
						|
        void                 (*GGML_CALL event_free)        (ggml_backend_event_t event);
 | 
						|
        // record an event on the backend instance that created it
 | 
						|
        void                 (*GGML_CALL event_record)      (ggml_backend_event_t event);
 | 
						|
        // wait for an event on on a different backend instance
 | 
						|
        void                 (*GGML_CALL event_wait)        (ggml_backend_t backend, ggml_backend_event_t event);
 | 
						|
        // block until an event is recorded
 | 
						|
        void                 (*GGML_CALL event_synchronize) (ggml_backend_event_t event);
 | 
						|
    };
 | 
						|
 | 
						|
    struct ggml_backend {
 | 
						|
        ggml_guid_t guid;
 | 
						|
 | 
						|
        struct ggml_backend_i iface;
 | 
						|
        ggml_backend_context_t context;
 | 
						|
    };
 | 
						|
 | 
						|
    struct ggml_backend_event {
 | 
						|
        ggml_backend_t backend;
 | 
						|
        void * context;
 | 
						|
    };
 | 
						|
 | 
						|
    //
    // Backend registry
    //
 | 
						|
 | 
						|
    // pointer to a backend initialization function: takes a params string and a user_data pointer, returns a backend handle
    typedef ggml_backend_t (*GGML_CALL ggml_backend_init_fn)(const char * params, void * user_data);
 | 
						|
 | 
						|
    // register a backend: a name, its init function, its default buffer type and a user_data pointer
    // NOTE(review): user_data is presumably handed back to init_fn when the backend is initialized - confirm
    GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data);
 | 
						|
 | 
						|
#ifdef  __cplusplus
 | 
						|
}
 | 
						|
#endif
 |