mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	mpi : trying to move more MPI stuff into ggml-mpi (WIP) (#2099)
This commit is contained in:
		@@ -34,7 +34,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) {
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
    fprintf(stderr, "%s: seed  = %d\n", __func__, params.seed);
 | 
					    fprintf(stderr, "%s: seed  = %d\n", __func__, params.seed);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_init_backend(params.numa);
 | 
					    llama_backend_init(params.numa);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_model * model;
 | 
					    llama_model * model;
 | 
				
			||||||
    llama_context * ctx;
 | 
					    llama_context * ctx;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -35,7 +35,7 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
        params.prompt = gpt_random_prompt(rng);
 | 
					        params.prompt = gpt_random_prompt(rng);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_init_backend(params.numa);
 | 
					    llama_backend_init(params.numa);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_model * model;
 | 
					    llama_model * model;
 | 
				
			||||||
    llama_context * ctx;
 | 
					    llama_context * ctx;
 | 
				
			||||||
@@ -93,5 +93,7 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
    llama_free(ctx);
 | 
					    llama_free(ctx);
 | 
				
			||||||
    llama_free_model(model);
 | 
					    llama_free_model(model);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    llama_backend_free();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return 0;
 | 
					    return 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -105,7 +105,7 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
        params.prompt = gpt_random_prompt(rng);
 | 
					        params.prompt = gpt_random_prompt(rng);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_init_backend(params.numa);
 | 
					    llama_backend_init(params.numa);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_model * model;
 | 
					    llama_model * model;
 | 
				
			||||||
    llama_context * ctx;
 | 
					    llama_context * ctx;
 | 
				
			||||||
@@ -671,7 +671,7 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
    llama_free(ctx);
 | 
					    llama_free(ctx);
 | 
				
			||||||
    llama_free_model(model);
 | 
					    llama_free_model(model);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_finalize_backend();
 | 
					    llama_backend_free();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return 0;
 | 
					    return 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -147,7 +147,7 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
        params.prompt = gpt_random_prompt(rng);
 | 
					        params.prompt = gpt_random_prompt(rng);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_init_backend(params.numa);
 | 
					    llama_backend_init(params.numa);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_model * model;
 | 
					    llama_model * model;
 | 
				
			||||||
    llama_context * ctx;
 | 
					    llama_context * ctx;
 | 
				
			||||||
@@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
    llama_free(ctx);
 | 
					    llama_free(ctx);
 | 
				
			||||||
    llama_free_model(model);
 | 
					    llama_free_model(model);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_finalize_backend();
 | 
					    llama_backend_free();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return 0;
 | 
					    return 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
        usage(argv[0]);
 | 
					        usage(argv[0]);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_init_backend(false);
 | 
					    llama_backend_init(false);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // parse command line arguments
 | 
					    // parse command line arguments
 | 
				
			||||||
    const std::string fname_inp = argv[arg_idx];
 | 
					    const std::string fname_inp = argv[arg_idx];
 | 
				
			||||||
@@ -257,5 +257,7 @@ int main(int argc, char ** argv) {
 | 
				
			|||||||
        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
 | 
					        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    llama_backend_free();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return 0;
 | 
					    return 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1079,7 +1079,7 @@ int main(int argc, char **argv)
 | 
				
			|||||||
        params.model_alias = params.model;
 | 
					        params.model_alias = params.model;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_init_backend(params.numa);
 | 
					    llama_backend_init(params.numa);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    LOG_INFO("build info", {{"build", BUILD_NUMBER},
 | 
					    LOG_INFO("build info", {{"build", BUILD_NUMBER},
 | 
				
			||||||
                            {"commit", BUILD_COMMIT}});
 | 
					                            {"commit", BUILD_COMMIT}});
 | 
				
			||||||
@@ -1309,5 +1309,7 @@ int main(int argc, char **argv)
 | 
				
			|||||||
        return 1;
 | 
					        return 1;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    llama_backend_free();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return 0;
 | 
					    return 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -66,7 +66,7 @@ int main(int argc, char ** argv)
 | 
				
			|||||||
    // Init LLM :
 | 
					    // Init LLM :
 | 
				
			||||||
    //---------------------------------
 | 
					    //---------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_init_backend(params.numa);
 | 
					    llama_backend_init(params.numa);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_model * model;
 | 
					    llama_model * model;
 | 
				
			||||||
    llama_context * ctx;
 | 
					    llama_context * ctx;
 | 
				
			||||||
@@ -173,7 +173,7 @@ int main(int argc, char ** argv)
 | 
				
			|||||||
    llama_free( ctx );
 | 
					    llama_free( ctx );
 | 
				
			||||||
    llama_free_model( model );
 | 
					    llama_free_model( model );
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    llama_finalize_backend();
 | 
					    llama_backend_free();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return 0;
 | 
					    return 0;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										66
									
								
								ggml-mpi.c
									
									
									
									
									
								
							
							
						
						
									
										66
									
								
								ggml-mpi.c
									
									
									
									
									
								
							@@ -2,9 +2,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
#include "ggml.h"
 | 
					#include "ggml.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <mpi.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include <stdio.h>
 | 
					#include <stdio.h>
 | 
				
			||||||
#include <stdlib.h>
 | 
					#include <stdlib.h>
 | 
				
			||||||
#include <mpi.h>
 | 
					
 | 
				
			||||||
#define UNUSED GGML_UNUSED
 | 
					#define UNUSED GGML_UNUSED
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct ggml_mpi_tensor_info {
 | 
					struct ggml_mpi_tensor_info {
 | 
				
			||||||
@@ -52,9 +54,8 @@ static void ggml_mpi_compute_forward_recv(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
struct ggml_tensor * ggml_mpi_send_tensor(
 | 
					struct ggml_tensor * ggml_mpi_send_tensor(
 | 
				
			||||||
        struct ggml_context * ctx,
 | 
					        struct ggml_context * ctx,
 | 
				
			||||||
        struct ggml_tensor *src,
 | 
					         struct ggml_tensor * src,
 | 
				
			||||||
                        int   dst_rank) {
 | 
					                        int   dst_rank) {
 | 
				
			||||||
 | 
					 | 
				
			||||||
    struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send);
 | 
					    struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // TODO how/when to free this struct?
 | 
					    // TODO how/when to free this struct?
 | 
				
			||||||
@@ -67,8 +68,8 @@ struct ggml_tensor * ggml_mpi_send_tensor(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
struct ggml_tensor * ggml_mpi_recv_tensor(
 | 
					struct ggml_tensor * ggml_mpi_recv_tensor(
 | 
				
			||||||
        struct ggml_context * ctx,
 | 
					        struct ggml_context * ctx,
 | 
				
			||||||
        struct ggml_tensor *parent,
 | 
					         struct ggml_tensor * parent,
 | 
				
			||||||
        struct ggml_tensor *dst,
 | 
					         struct ggml_tensor * dst,
 | 
				
			||||||
                        int   src_rank) {
 | 
					                        int   src_rank) {
 | 
				
			||||||
    struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv);
 | 
					    struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -79,3 +80,58 @@ struct ggml_tensor * ggml_mpi_recv_tensor(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    return result;
 | 
					    return result;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ggml_mpi_context {
 | 
				
			||||||
 | 
					    int mpi_rank;
 | 
				
			||||||
 | 
					    int mpi_size;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void ggml_mpi_backend_init(void) {
 | 
				
			||||||
 | 
					    MPI_Init(NULL, NULL);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void ggml_mpi_backend_free(void) {
 | 
				
			||||||
 | 
					    MPI_Finalize();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ggml_mpi_context * ggml_mpi_init(void) {
 | 
				
			||||||
 | 
					    struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank);
 | 
				
			||||||
 | 
					    MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return ctx;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void ggml_mpi_free(struct ggml_mpi_context * ctx) {
 | 
				
			||||||
 | 
					    free(ctx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int ggml_mpi_rank(struct ggml_mpi_context * ctx) {
 | 
				
			||||||
 | 
					    return ctx->mpi_rank;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ggml_tensor * ggml_mpi_eval_init(
 | 
				
			||||||
 | 
					        struct ggml_mpi_context * ctx_mpi,
 | 
				
			||||||
 | 
					        struct ggml_context     * ctx,
 | 
				
			||||||
 | 
					                            int   n_embd,
 | 
				
			||||||
 | 
					                            int * n_tokens,
 | 
				
			||||||
 | 
					                            int * n_past,
 | 
				
			||||||
 | 
					                            int * n_threads) {
 | 
				
			||||||
 | 
					    struct ggml_tensor * res = NULL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // synchronize the worker node parameters with the root node
 | 
				
			||||||
 | 
					    MPI_Barrier(MPI_COMM_WORLD);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    MPI_Bcast(n_tokens,  1, MPI_INT, 0, MPI_COMM_WORLD);
 | 
				
			||||||
 | 
					    MPI_Bcast(n_past,    1, MPI_INT, 0, MPI_COMM_WORLD);
 | 
				
			||||||
 | 
					    MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (ctx_mpi->mpi_rank > 0) {
 | 
				
			||||||
 | 
					        res = ggml_mpi_recv_tensor(ctx, NULL,
 | 
				
			||||||
 | 
					                ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, *n_tokens), ctx_mpi->mpi_rank - 1);
 | 
				
			||||||
 | 
					        ggml_set_name(res, "mpi_recv");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return res;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										24
									
								
								ggml-mpi.h
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								ggml-mpi.h
									
									
									
									
									
								
							@@ -9,14 +9,32 @@ extern "C" {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
struct ggml_tensor * ggml_mpi_send_tensor(
 | 
					struct ggml_tensor * ggml_mpi_send_tensor(
 | 
				
			||||||
        struct ggml_context * ctx,
 | 
					        struct ggml_context * ctx,
 | 
				
			||||||
        struct ggml_tensor *src,
 | 
					         struct ggml_tensor * src,
 | 
				
			||||||
                        int   dst_rank);
 | 
					                        int   dst_rank);
 | 
				
			||||||
struct ggml_tensor * ggml_mpi_recv_tensor(
 | 
					struct ggml_tensor * ggml_mpi_recv_tensor(
 | 
				
			||||||
        struct ggml_context * ctx,
 | 
					        struct ggml_context * ctx,
 | 
				
			||||||
        struct ggml_tensor *parent,
 | 
					         struct ggml_tensor * parent,
 | 
				
			||||||
        struct ggml_tensor *dst,
 | 
					         struct ggml_tensor * dst,
 | 
				
			||||||
                        int   src_rank);
 | 
					                        int   src_rank);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ggml_mpi_context;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void ggml_mpi_backend_init(void);
 | 
				
			||||||
 | 
					void ggml_mpi_backend_free(void);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ggml_mpi_context * ggml_mpi_init(void);
 | 
				
			||||||
 | 
					void ggml_mpi_free(struct ggml_mpi_context * ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					int ggml_mpi_rank(struct ggml_mpi_context * ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ggml_tensor * ggml_mpi_eval_init(
 | 
				
			||||||
 | 
					        struct ggml_mpi_context * ctx_mpi,
 | 
				
			||||||
 | 
					        struct ggml_context     * ctx,
 | 
				
			||||||
 | 
					                            int   n_embd,
 | 
				
			||||||
 | 
					                            int * n_tokens,
 | 
				
			||||||
 | 
					                            int * n_past,
 | 
				
			||||||
 | 
					                            int * n_threads);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#ifdef __cplusplus
 | 
					#ifdef __cplusplus
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										73
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										73
									
								
								llama.cpp
									
									
									
									
									
								
							@@ -52,10 +52,6 @@
 | 
				
			|||||||
#include <sstream>
 | 
					#include <sstream>
 | 
				
			||||||
#include <numeric>
 | 
					#include <numeric>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#ifdef GGML_USE_MPI
 | 
					 | 
				
			||||||
#include <mpi.h>
 | 
					 | 
				
			||||||
#endif
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#if defined(_MSC_VER)
 | 
					#if defined(_MSC_VER)
 | 
				
			||||||
#pragma warning(disable: 4244 4267) // possible loss of data
 | 
					#pragma warning(disable: 4244 4267) // possible loss of data
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
@@ -337,8 +333,9 @@ struct llama_context {
 | 
				
			|||||||
    ggml_metal_context * ctx_metal = NULL;
 | 
					    ggml_metal_context * ctx_metal = NULL;
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int    mpi_rank;
 | 
					#ifdef GGML_USE_MPI
 | 
				
			||||||
    int    mpi_size;
 | 
					    ggml_mpi_context * ctx_mpi = NULL;
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int    buf_last = 0;
 | 
					    int    buf_last = 0;
 | 
				
			||||||
    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
 | 
					    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
 | 
				
			||||||
@@ -859,7 +856,7 @@ bool llama_mlock_supported() {
 | 
				
			|||||||
    return llama_mlock::SUPPORTED;
 | 
					    return llama_mlock::SUPPORTED;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void llama_init_backend(bool numa) {
 | 
					void llama_backend_init(bool numa) {
 | 
				
			||||||
    ggml_time_init();
 | 
					    ggml_time_init();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // needed to initialize f16 tables
 | 
					    // needed to initialize f16 tables
 | 
				
			||||||
@@ -872,14 +869,15 @@ void llama_init_backend(bool numa) {
 | 
				
			|||||||
    if (numa) {
 | 
					    if (numa) {
 | 
				
			||||||
        ggml_numa_init();
 | 
					        ggml_numa_init();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#ifdef GGML_USE_MPI
 | 
					#ifdef GGML_USE_MPI
 | 
				
			||||||
    MPI_Init(NULL, NULL);
 | 
					    ggml_mpi_backend_init();
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void llama_finalize_backend() {
 | 
					void llama_backend_free() {
 | 
				
			||||||
#ifdef GGML_USE_MPI
 | 
					#ifdef GGML_USE_MPI
 | 
				
			||||||
    MPI_Finalize();
 | 
					    ggml_mpi_backend_free();
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -1282,9 +1280,9 @@ static bool llama_eval_internal(
 | 
				
			|||||||
         llama_context & lctx,
 | 
					         llama_context & lctx,
 | 
				
			||||||
     const llama_token * tokens,
 | 
					     const llama_token * tokens,
 | 
				
			||||||
           const float * embd,
 | 
					           const float * embd,
 | 
				
			||||||
             const int   n_tokens,
 | 
					                   int   n_tokens,
 | 
				
			||||||
             const int   n_past,
 | 
					                   int   n_past,
 | 
				
			||||||
             const int   n_threads,
 | 
					                   int   n_threads,
 | 
				
			||||||
            const char * cgraph_fname) {
 | 
					            const char * cgraph_fname) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
 | 
					    LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));
 | 
				
			||||||
@@ -1333,16 +1331,14 @@ static bool llama_eval_internal(
 | 
				
			|||||||
    struct ggml_tensor * cur;
 | 
					    struct ggml_tensor * cur;
 | 
				
			||||||
    struct ggml_tensor * inpL;
 | 
					    struct ggml_tensor * inpL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (lctx.mpi_rank > 0) {
 | 
					 | 
				
			||||||
#ifdef GGML_USE_MPI
 | 
					#ifdef GGML_USE_MPI
 | 
				
			||||||
        inpL = ggml_mpi_recv_tensor(ctx0, NULL,
 | 
					    inpL = ggml_mpi_eval_init(lctx.ctx_mpi, ctx0, n_embd, &n_tokens, &n_past, &n_threads);
 | 
				
			||||||
                ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N),
 | 
					
 | 
				
			||||||
                lctx.mpi_rank-1);
 | 
					    if (inpL) {
 | 
				
			||||||
        ggml_set_name(inpL, "mpi_recv");
 | 
					        // only rank 0 loads uses the input
 | 
				
			||||||
#else
 | 
					    } else
 | 
				
			||||||
        GGML_ASSERT(false);
 | 
					 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
    } else if (tokens) {
 | 
					    if (tokens) {
 | 
				
			||||||
        struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
 | 
					        struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
 | 
				
			||||||
        ggml_set_name(embd, "embd");
 | 
					        ggml_set_name(embd, "embd");
 | 
				
			||||||
        memcpy(embd->data, tokens, N*ggml_element_size(embd));
 | 
					        memcpy(embd->data, tokens, N*ggml_element_size(embd));
 | 
				
			||||||
@@ -1585,7 +1581,6 @@ static bool llama_eval_internal(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        // input for next layer
 | 
					        // input for next layer
 | 
				
			||||||
        inpL = cur;
 | 
					        inpL = cur;
 | 
				
			||||||
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    lctx.use_buf(ctx0, 0);
 | 
					    lctx.use_buf(ctx0, 0);
 | 
				
			||||||
@@ -1601,6 +1596,7 @@ static bool llama_eval_internal(
 | 
				
			|||||||
        GGML_ASSERT(false);
 | 
					        GGML_ASSERT(false);
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (lctx.mpi_rank == 0) {
 | 
					    if (lctx.mpi_rank == 0) {
 | 
				
			||||||
        if (lctx.mpi_size > 1) {
 | 
					        if (lctx.mpi_size > 1) {
 | 
				
			||||||
#ifdef GGML_USE_MPI
 | 
					#ifdef GGML_USE_MPI
 | 
				
			||||||
@@ -1688,7 +1684,11 @@ static bool llama_eval_internal(
 | 
				
			|||||||
    // update kv token count
 | 
					    // update kv token count
 | 
				
			||||||
    lctx.kv_self.n = n_past + N;
 | 
					    lctx.kv_self.n = n_past + N;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (lctx.mpi_rank == 0) {
 | 
					#ifdef GGML_USE_MPI
 | 
				
			||||||
 | 
					    if (ggml_mpi_rank(lctx.ctx_mpi) == 0) {
 | 
				
			||||||
 | 
					#else
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
        // extract logits
 | 
					        // extract logits
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            auto & logits_out = lctx.logits;
 | 
					            auto & logits_out = lctx.logits;
 | 
				
			||||||
@@ -2659,14 +2659,6 @@ struct llama_context * llama_new_context_with_model(
 | 
				
			|||||||
    ctx->rng = std::mt19937(params.seed);
 | 
					    ctx->rng = std::mt19937(params.seed);
 | 
				
			||||||
    ctx->logits_all = params.logits_all;
 | 
					    ctx->logits_all = params.logits_all;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#ifdef GGML_USE_MPI
 | 
					 | 
				
			||||||
    MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size);
 | 
					 | 
				
			||||||
    MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank);
 | 
					 | 
				
			||||||
#else
 | 
					 | 
				
			||||||
    ctx->mpi_size = 1;
 | 
					 | 
				
			||||||
    ctx->mpi_rank = 0;
 | 
					 | 
				
			||||||
#endif
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 | 
					    ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // reserve memory for context buffers
 | 
					    // reserve memory for context buffers
 | 
				
			||||||
@@ -2739,15 +2731,17 @@ struct llama_context * llama_new_context_with_model(
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (ctx->mpi_rank > 0) {
 | 
					#ifdef GGML_USE_MPI
 | 
				
			||||||
 | 
					    ctx->ctx_mpi = ggml_mpi_init();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (ggml_mpi_rank(ctx->ctx_mpi) > 0) {
 | 
				
			||||||
        // Enter a blocking eval loop with dummy input, letting rank=0 drive the process
 | 
					        // Enter a blocking eval loop with dummy input, letting rank=0 drive the process
 | 
				
			||||||
        const std::vector<llama_token> tmp = { llama_token_bos(), };
 | 
					        const std::vector<llama_token> tmp = { llama_token_bos(), };
 | 
				
			||||||
        while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0));
 | 
					        while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {};
 | 
				
			||||||
#ifdef GGML_USE_MPI
 | 
					        llama_backend_free();
 | 
				
			||||||
        MPI_Finalize();
 | 
					 | 
				
			||||||
#endif
 | 
					 | 
				
			||||||
        exit(1);
 | 
					        exit(1);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					#endif
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return ctx;
 | 
					    return ctx;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -3425,13 +3419,6 @@ int llama_eval(
 | 
				
			|||||||
                         int   n_tokens,
 | 
					                         int   n_tokens,
 | 
				
			||||||
                         int   n_past,
 | 
					                         int   n_past,
 | 
				
			||||||
                         int   n_threads) {
 | 
					                         int   n_threads) {
 | 
				
			||||||
#ifdef GGML_USE_MPI
 | 
					 | 
				
			||||||
    // Synchronize the worker node parameters with the root node
 | 
					 | 
				
			||||||
    MPI_Barrier(MPI_COMM_WORLD);
 | 
					 | 
				
			||||||
    MPI_Bcast(&n_past, 1, MPI_INT, 0, MPI_COMM_WORLD);
 | 
					 | 
				
			||||||
    MPI_Bcast(&n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD);
 | 
					 | 
				
			||||||
    MPI_Bcast(&n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD);
 | 
					 | 
				
			||||||
#endif
 | 
					 | 
				
			||||||
    if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) {
 | 
					    if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) {
 | 
				
			||||||
        fprintf(stderr, "%s: failed to eval\n", __func__);
 | 
					        fprintf(stderr, "%s: failed to eval\n", __func__);
 | 
				
			||||||
        return 1;
 | 
					        return 1;
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										4
									
								
								llama.h
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								llama.h
									
									
									
									
									
								
							@@ -158,9 +158,9 @@ extern "C" {
 | 
				
			|||||||
    // Initialize the llama + ggml backend
 | 
					    // Initialize the llama + ggml backend
 | 
				
			||||||
    // If numa is true, use NUMA optimizations
 | 
					    // If numa is true, use NUMA optimizations
 | 
				
			||||||
    // Call once at the start of the program
 | 
					    // Call once at the start of the program
 | 
				
			||||||
    LLAMA_API void llama_init_backend(bool numa);
 | 
					    LLAMA_API void llama_backend_init(bool numa);
 | 
				
			||||||
    // Call once at the end of the program - currently only used for MPI
 | 
					    // Call once at the end of the program - currently only used for MPI
 | 
				
			||||||
    LLAMA_API void llama_finalize_backend();
 | 
					    LLAMA_API void llama_backend_free();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    LLAMA_API int64_t llama_time_us();
 | 
					    LLAMA_API int64_t llama_time_us();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user