mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	SYCL : support non-contiguous tensors in binary ops (add, sub, etc) (#12399)
* sycl : support non-contiguous tensors in binary ops
* sycl : silence unused variable warning

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
		| @@ -474,6 +474,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, | |||||||
|         int ne0, int ne1, int ne2, int ne3, |         int ne0, int ne1, int ne2, int ne3, | ||||||
|         int ne10, int ne11, int ne12, int ne13, |         int ne10, int ne11, int ne12, int ne13, | ||||||
|         /*int s0, */ int s1,  int s2,  int s3, |         /*int s0, */ int s1,  int s2,  int s3, | ||||||
|  |         /*int s00,*/ int s01, int s02, int s03, | ||||||
|         /*int s10,*/ int s11, int s12, int s13, |         /*int s10,*/ int s11, int s12, int s13, | ||||||
|         const sycl::nd_item<3> &item_ct1) { |         const sycl::nd_item<3> &item_ct1) { | ||||||
|     const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + |     const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + | ||||||
| @@ -495,9 +496,9 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, | |||||||
|     const int i12 = i2 % ne12; |     const int i12 = i2 % ne12; | ||||||
|     const int i13 = i3 % ne13; |     const int i13 = i3 % ne13; | ||||||
|  |  | ||||||
|     const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; |     const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01; | ||||||
|     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; |     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; | ||||||
|     const size_t i_dst  = i_src0; |     const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1; | ||||||
|  |  | ||||||
|     const src0_t * src0_row = src0 + i_src0; |     const src0_t * src0_row = src0 + i_src0; | ||||||
|     const src1_t * src1_row = src1 + i_src1; |     const src1_t * src1_row = src1 + i_src1; | ||||||
| @@ -515,6 +516,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t | |||||||
|         int ne0, int ne1, int ne2, int ne3, |         int ne0, int ne1, int ne2, int ne3, | ||||||
|         int ne10, int ne11, int ne12, int ne13, |         int ne10, int ne11, int ne12, int ne13, | ||||||
|         /*int s0, */ int s1,  int s2,  int s3, |         /*int s0, */ int s1,  int s2,  int s3, | ||||||
|  |         /*int s00,*/ int s01, int s02, int s03, | ||||||
|         /*int s10,*/ int s11, int s12, int s13, |         /*int s10,*/ int s11, int s12, int s13, | ||||||
|         const sycl::nd_item<3> &item_ct1) { |         const sycl::nd_item<3> &item_ct1) { | ||||||
|  |  | ||||||
| @@ -534,9 +536,9 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t | |||||||
|     const int i12 = i2 % ne12; |     const int i12 = i2 % ne12; | ||||||
|     const int i13 = i3 % ne13; |     const int i13 = i3 % ne13; | ||||||
|  |  | ||||||
|     const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; |     const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01; | ||||||
|     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; |     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; | ||||||
|     const size_t i_dst  = i_src0; |     const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1; | ||||||
|  |  | ||||||
|     const src0_t * src0_row = src0 + i_src0; |     const src0_t * src0_row = src0 + i_src0; | ||||||
|     const src1_t * src1_row = src1 + i_src1; |     const src1_t * src1_row = src1 + i_src1; | ||||||
| @@ -566,9 +568,11 @@ struct bin_bcast_sycl { | |||||||
|         int nr[4] = { nr0, nr1, nr2, nr3 }; |         int nr[4] = { nr0, nr1, nr2, nr3 }; | ||||||
|  |  | ||||||
|         // collapse dimensions until first broadcast dimension |         // collapse dimensions until first broadcast dimension | ||||||
|         int64_t cne0[] = {ne0, ne1, ne2, ne3}; |         int64_t cne[] = {ne0, ne1, ne2, ne3}; | ||||||
|  |         int64_t cne0[] = {ne00, ne01, ne02, ne03}; | ||||||
|         int64_t cne1[] = {ne10, ne11, ne12, ne13}; |         int64_t cne1[] = {ne10, ne11, ne12, ne13}; | ||||||
|         size_t cnb0[] = {nb0, nb1, nb2, nb3}; |         size_t cnb[] = {nb0, nb1, nb2, nb3}; | ||||||
|  |         size_t cnb0[] = {nb00, nb01, nb02, nb03}; | ||||||
|         size_t cnb1[] = {nb10, nb11, nb12, nb13}; |         size_t cnb1[] = {nb10, nb11, nb12, nb13}; | ||||||
|         auto collapse = [](int64_t cne[]) { |         auto collapse = [](int64_t cne[]) { | ||||||
|             cne[0] *= cne[1]; |             cne[0] *= cne[1]; | ||||||
| @@ -583,32 +587,41 @@ struct bin_bcast_sycl { | |||||||
|             cnb[3] *= cne[3]; |             cnb[3] *= cne[3]; | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         for (int i = 0; i < 4; i++) { |         if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { | ||||||
|             if (nr[i] != 1) { |             for (int i = 0; i < 4; i++) { | ||||||
|                 break; |                 if (nr[i] != 1) { | ||||||
|             } |                     break; | ||||||
|             if (i > 0) { |                 } | ||||||
|                 collapse_nb(cnb0, cne0); |                 if (i > 0) { | ||||||
|                 collapse_nb(cnb1, cne1); |                     collapse_nb(cnb, cne); | ||||||
|                 collapse(cne0); |                     collapse_nb(cnb0, cne0); | ||||||
|                 collapse(cne1); |                     collapse_nb(cnb1, cne1); | ||||||
|  |                     collapse(cne); | ||||||
|  |                     collapse(cne0); | ||||||
|  |                     collapse(cne1); | ||||||
|  |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|         { |         { | ||||||
|             int64_t ne0 = cne0[0]; |             int64_t ne0 = cne[0]; | ||||||
|             int64_t ne1 = cne0[1]; |             int64_t ne1 = cne[1]; | ||||||
|             int64_t ne2 = cne0[2]; |             int64_t ne2 = cne[2]; | ||||||
|             int64_t ne3 = cne0[3]; |             int64_t ne3 = cne[3]; | ||||||
|  |  | ||||||
|             int64_t ne10 = cne1[0]; |             int64_t ne10 = cne1[0]; | ||||||
|             int64_t ne11 = cne1[1]; |             int64_t ne11 = cne1[1]; | ||||||
|             int64_t ne12 = cne1[2]; |             int64_t ne12 = cne1[2]; | ||||||
|             int64_t ne13 = cne1[3]; |             int64_t ne13 = cne1[3]; | ||||||
|  |  | ||||||
|             size_t nb0 = cnb0[0]; |             size_t nb0 = cnb[0]; | ||||||
|             size_t nb1 = cnb0[1]; |             size_t nb1 = cnb[1]; | ||||||
|             size_t nb2 = cnb0[2]; |             size_t nb2 = cnb[2]; | ||||||
|             size_t nb3 = cnb0[3]; |             size_t nb3 = cnb[3]; | ||||||
|  |  | ||||||
|  |             size_t nb00 = cnb0[0]; | ||||||
|  |             size_t nb01 = cnb0[1]; | ||||||
|  |             size_t nb02 = cnb0[2]; | ||||||
|  |             size_t nb03 = cnb0[3]; | ||||||
|  |  | ||||||
|             size_t nb10 = cnb1[0]; |             size_t nb10 = cnb1[0]; | ||||||
|             size_t nb11 = cnb1[1]; |             size_t nb11 = cnb1[1]; | ||||||
| @@ -625,6 +638,28 @@ struct bin_bcast_sycl { | |||||||
|             size_t s12 = nb12 / sizeof(src1_t); |             size_t s12 = nb12 / sizeof(src1_t); | ||||||
|             size_t s13 = nb13 / sizeof(src1_t); |             size_t s13 = nb13 / sizeof(src1_t); | ||||||
|  |  | ||||||
|  |             size_t s00 = nb00 / sizeof(src0_t); | ||||||
|  |             size_t s01 = nb01 / sizeof(src0_t); | ||||||
|  |             size_t s02 = nb02 / sizeof(src0_t); | ||||||
|  |             size_t s03 = nb03 / sizeof(src0_t); | ||||||
|  |  | ||||||
|  |             GGML_UNUSED(s00); | ||||||
|  |  | ||||||
|  |             GGML_ASSERT(nb0 % sizeof(dst_t) == 0); | ||||||
|  |             GGML_ASSERT(nb1 % sizeof(dst_t) == 0); | ||||||
|  |             GGML_ASSERT(nb2 % sizeof(dst_t) == 0); | ||||||
|  |             GGML_ASSERT(nb3 % sizeof(dst_t) == 0); | ||||||
|  |  | ||||||
|  |             GGML_ASSERT(nb00 % sizeof(src0_t) == 0); | ||||||
|  |             GGML_ASSERT(nb01 % sizeof(src0_t) == 0); | ||||||
|  |             GGML_ASSERT(nb02 % sizeof(src0_t) == 0); | ||||||
|  |             GGML_ASSERT(nb03 % sizeof(src0_t) == 0); | ||||||
|  |  | ||||||
|  |             GGML_ASSERT(nb10 % sizeof(src1_t) == 0); | ||||||
|  |             GGML_ASSERT(nb11 % sizeof(src1_t) == 0); | ||||||
|  |             GGML_ASSERT(nb12 % sizeof(src1_t) == 0); | ||||||
|  |             GGML_ASSERT(nb13 % sizeof(src1_t) == 0); | ||||||
|  |  | ||||||
|             GGML_ASSERT(s0 == 1); |             GGML_ASSERT(s0 == 1); | ||||||
|             GGML_ASSERT(s10 == 1); |             GGML_ASSERT(s10 == 1); | ||||||
|  |  | ||||||
| @@ -661,8 +696,8 @@ struct bin_bcast_sycl { | |||||||
|                         [=](sycl::nd_item<3> item_ct1) { |                         [=](sycl::nd_item<3> item_ct1) { | ||||||
|                             k_bin_bcast_unravel<bin_op>( |                             k_bin_bcast_unravel<bin_op>( | ||||||
|                                 src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, |                                 src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, | ||||||
|                                 ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12, |                                 ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, | ||||||
|                                 s13, item_ct1); |                                 s03, s11, s12, s13, item_ct1); | ||||||
|                         }); |                         }); | ||||||
|                 } |                 } | ||||||
|             } else { |             } else { | ||||||
| @@ -680,7 +715,7 @@ struct bin_bcast_sycl { | |||||||
|                     [=](sycl::nd_item<3> item_ct1) { |                     [=](sycl::nd_item<3> item_ct1) { | ||||||
|                         k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, |                         k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, | ||||||
|                                             ne2, ne3, ne10, ne11, ne12, ne13, |                                             ne2, ne3, ne10, ne11, ne12, ne13, | ||||||
|                                             s1, s2, s3, s11, s12, s13, |                                             s1, s2, s3, s01, s02, s03, s11, s12, s13, | ||||||
|                                             item_ct1); |                                             item_ct1); | ||||||
|                     }); |                     }); | ||||||
|             } |             } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 fairydreaming
					fairydreaming