	sgemm : improved Q4_0 and Q8_0 performance via 4xN and Mx4 gemm (#8908)
When AVX2 and F16C are available, mnpack now dispatches the 4-row and 4-column tile cases to two new specialized kernels, gemm4xN and gemmMx4, keeping the generic gemm<RM, RN> template as the fallback:
@@ -606,17 +606,29 @@ class tinyBLAS_Q0_AVX {
         case 0x44:
             mc = 4;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<4>(m0, m, n0, n);
+#else
             gemm<4, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x43:
             mc = 4;
             nc = 3;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<3>(m0, m, n0, n);
+#else
             gemm<4, 3>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
             mc = 3;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<3>(m0, m, n0, n);
+#else
             gemm<3, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
             mc = 3;
@@ -626,12 +638,20 @@ class tinyBLAS_Q0_AVX {
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
 #else
         case 0x44:
@@ -639,13 +659,21 @@ class tinyBLAS_Q0_AVX {
         case 0x42:
             mc = 4;
             nc = 2;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<2>(m0, m, n0, n);
+#else
             gemm<4, 2>(m0, m, n0, n);
+#endif
             break;
         case 0x34:
         case 0x24:
             mc = 2;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<2>(m0, m, n0, n);
+#else
             gemm<2, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x33:
 #endif
@@ -662,7 +690,11 @@ class tinyBLAS_Q0_AVX {
         case 0x41:
             mc = 4;
             nc = 1;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemm4xN<1>(m0, m, n0, n);
+#else
             gemm<4, 1>(m0, m, n0, n);
+#endif
             break;
         case 0x22:
             mc = 2;
@@ -672,7 +704,11 @@ class tinyBLAS_Q0_AVX {
         case 0x14:
             mc = 1;
             nc = 4;
+#if defined(__AVX2__) && defined(__F16C__)
+            gemmMx4<1>(m0, m, n0, n);
+#else
             gemm<1, 4>(m0, m, n0, n);
+#endif
             break;
         case 0x31:
             mc = 3;
@@ -708,6 +744,119 @@ class tinyBLAS_Q0_AVX {
         mnpack(m0, m, np, n);
     }

+#if defined(__AVX2__) && defined(__F16C__)
+    // Templated functions for gemm of dimensions 4xN
+    template <int RN>
+    NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / 4;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * 4;
+            int64_t jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][4] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
+                // Convert the delta values of the four blocks to float
+                __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
+                __m256i avec0 = load(A + lda * (ii + 0) + l);
+                __m256i avec1 = load(A + lda * (ii + 1) + l);
+                __m256i avec2 = load(A + lda * (ii + 2) + l);
+                __m256i avec3 = load(A + lda * (ii + 3) + l);
+                for (int64_t j = 0; j < RN; ++j) {
+                    __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
+                    // Compute the product of the delta values for the four blocks and replicate it across the 256-bit lane
+                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
+                    // Compute the dot products and scale them by the corresponding delta products
+                    Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(avec0, avec0),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
+                                    Cv[j][0]);
+                    Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(avec1, avec1),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
+                                    Cv[j][1]);
+                    Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(avec2, avec2),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
+                                    Cv[j][2]);
+                    Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(avec3, avec3),
+                                          _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
+                                    Cv[j][3]);
+                }
+            }
+
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < 4; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+
+    // Templated functions for gemm of dimensions Mx4
+    template <int RM>
+    NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / 4;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * 4;
+            __m256 Cv[4][RM] = {};
+            for (int64_t l = 0; l < k; ++l) {
+                uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
+                // Convert the delta values of the four blocks to float
+                __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
+                __m256i bvec0 = load(B + ldb * (jj + 0) + l);
+                __m256i bvec1 = load(B + ldb * (jj + 1) + l);
+                __m256i bvec2 = load(B + ldb * (jj + 2) + l);
+                __m256i bvec3 = load(B + ldb * (jj + 3) + l);
+                for (int64_t i = 0; i < RM; ++i) {
+                    __m128 da = _mm_set1_ps(unhalf(A[lda * (ii + i) + l].d));
+                    // Compute the product of the delta values for the four blocks and replicate it across the 256-bit lane
+                    __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
+                    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
+                    // Compute the dot products and scale them by the corresponding delta products
+                    Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
+                                    Cv[0][i]);
+                    Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
+                                    Cv[1][i]);
+                    Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
+                                    Cv[2][i]);
+                    Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
+                                    updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
+                                                           load(A + lda * (ii + i) + l)),
+                                          _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
+                                    Cv[3][i]);
+                }
+            }
+            for (int64_t j = 0; j < 4; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
+        }
+    }
+#endif
+
     template <int RM, int RN>
     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
         int64_t ytiles = (m - m0) / RM;
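Both new kernels keep the thread-partitioning scheme of the generic gemm template: output tiles are numbered row-major over the (ytiles x xtiles) grid, and each of the nth threads claims a contiguous chunk of ceil(tiles / nth) jobs. A standalone sketch of just that partitioning follows; the helper name and loop body are illustrative, not taken from the file.

#include <cstdint>

// Sketch: the tile partitioning shared by gemm, gemm4xN, and gemmMx4.
// Tiles are numbered row-major; thread `ith` of `nth` takes a contiguous
// chunk of ceil(tiles / nth) jobs, so no two threads write the same tile.
static void for_each_tile(int64_t ytiles, int64_t xtiles, int ith, int nth) {
    int64_t tiles = xtiles * ytiles;
    int64_t duty  = (tiles + nth - 1) / nth;  // ceil(tiles / nth)
    int64_t start = duty * ith;
    int64_t end   = start + duty;
    if (end > tiles)
        end = tiles;
    for (int64_t job = start; job < end; ++job) {
        int64_t tile_row = job / xtiles;      // which 4-row (or RM-row) strip
        int64_t tile_col = job % xtiles;      // which RN-column (or 4-column) strip
        // ... compute output tile (tile_row, tile_col) here ...
        (void)tile_row; (void)tile_col;
    }
}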
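The F16C-specific gain is in the handling of the per-block scales: the four fp16 `d` fields of four adjacent rows (or columns) of blocks are packed into one uint64_t, converted to floats by a single _mm_cvtph_ps, multiplied by the other operand's scale, and each product is then splatted across a full 256-bit register with _mm256_permute2f128_ps plus _mm256_shuffle_ps. Below is a minimal sketch of that convert-and-splat step, assuming d0..d3 are raw fp16 bit patterns such as the .d fields above (the helper name is hypothetical); compile with -mavx2 -mf16c.

#include <immintrin.h>
#include <stdint.h>

// Sketch: convert four fp16 block scales to fp32 at once and broadcast each
// one across a 256-bit register, mirroring the pattern in gemm4xN/gemmMx4.
static inline void broadcast_scales(uint16_t d0, uint16_t d1, uint16_t d2,
                                    uint16_t d3, __m256 out[4]) {
    // Pack the four fp16 values into the low 64 bits of an __m128i.
    uint64_t packed = ((uint64_t)d3 << 48) | ((uint64_t)d2 << 32) |
                      ((uint64_t)d1 << 16) |  (uint64_t)d0;
    // One F16C instruction converts all four halves to floats.
    __m128 quad = _mm_cvtph_ps(_mm_set_epi64x(0, packed));
    // Duplicate the 128-bit vector into both halves of a 256-bit register...
    __m256 dvec = _mm256_castps128_ps256(quad);
    dvec = _mm256_permute2f128_ps(dvec, dvec, 0);
    // ...then splat each element with in-lane shuffles: immediates
    // 0b00000000, 0b01010101, 0b10101010, 0b11111111 select elements 0..3.
    out[0] = _mm256_shuffle_ps(dvec, dvec, 0);
    out[1] = _mm256_shuffle_ps(dvec, dvec, 85);
    out[2] = _mm256_shuffle_ps(dvec, dvec, 170);
    out[3] = _mm256_shuffle_ps(dvec, dvec, 255);
}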
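Finally, the updot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x)) idiom in both kernels exists because _mm256_maddubs_epi16 multiplies an unsigned byte operand by a signed one: _mm256_sign_epi8(x, x) yields |x|, and _mm256_sign_epi8(y, x) multiplies y by the sign of x, so each byte product still equals x*y. A hedged sketch of such a reduction on plain AVX2 (the file's actual updot may also take a VNNI path; this variant is an assumption):

#include <immintrin.h>

// Sketch: dot product of 32 signed int8 pairs via the sign trick.
// maddubs needs an unsigned first operand, so we pass |x| and y*sign(x):
//   |x| * (y * sign(x)) == x * y  for every byte.
// (Where x == 0, sign_epi8 zeroes y, and the true product is 0 anyway.)
static inline __m256 updot_avx2(__m256i x, __m256i y) {
    __m256i ux  = _mm256_sign_epi8(x, x);            // |x|
    __m256i sy  = _mm256_sign_epi8(y, x);            // y with x's signs applied
    __m256i p16 = _mm256_maddubs_epi16(ux, sy);      // pairwise u8*s8 -> i16
    __m256i p32 = _mm256_madd_epi16(p16, _mm256_set1_epi16(1)); // i16 pairs -> i32
    return _mm256_cvtepi32_ps(p32);                  // float, ready to madd() with the scales
}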
	