mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	sgemm : improved Q4_0 and Q8_0 performance via 4xN and Mx4 gemm (#8908)
This commit is contained in:
		| @@ -606,17 +606,29 @@ class tinyBLAS_Q0_AVX { | ||||
|         case 0x44: | ||||
|             mc = 4; | ||||
|             nc = 4; | ||||
| #if defined(__AVX2__) && defined(__F16C__) | ||||
|             gemm4xN<4>(m0, m, n0, n); | ||||
| #else | ||||
|             gemm<4, 4>(m0, m, n0, n); | ||||
| #endif | ||||
|             break; | ||||
|         case 0x43: | ||||
|             mc = 4; | ||||
|             nc = 3; | ||||
| #if defined(__AVX2__) && defined(__F16C__) | ||||
|             gemm4xN<3>(m0, m, n0, n); | ||||
| #else | ||||
|             gemm<4, 3>(m0, m, n0, n); | ||||
| #endif | ||||
|             break; | ||||
|         case 0x34: | ||||
|             mc = 3; | ||||
|             nc = 4; | ||||
| #if defined(__AVX2__) && defined(__F16C__) | ||||
|             gemmMx4<3>(m0, m, n0, n); | ||||
| #else | ||||
|             gemm<3, 4>(m0, m, n0, n); | ||||
| #endif | ||||
|             break; | ||||
|         case 0x33: | ||||
|             mc = 3; | ||||
| @@ -626,12 +638,20 @@ class tinyBLAS_Q0_AVX { | ||||
|         case 0x42: | ||||
|             mc = 4; | ||||
|             nc = 2; | ||||
| #if defined(__AVX2__) && defined(__F16C__) | ||||
|             gemm4xN<2>(m0, m, n0, n); | ||||
| #else | ||||
|             gemm<4, 2>(m0, m, n0, n); | ||||
| #endif | ||||
|             break; | ||||
|         case 0x24: | ||||
|             mc = 2; | ||||
|             nc = 4; | ||||
| #if defined(__AVX2__) && defined(__F16C__) | ||||
|             gemmMx4<2>(m0, m, n0, n); | ||||
| #else | ||||
|             gemm<2, 4>(m0, m, n0, n); | ||||
| #endif | ||||
|             break; | ||||
| #else | ||||
|         case 0x44: | ||||
| @@ -639,13 +659,21 @@ class tinyBLAS_Q0_AVX { | ||||
|         case 0x42: | ||||
|             mc = 4; | ||||
|             nc = 2; | ||||
| #if defined(__AVX2__) && defined(__F16C__) | ||||
|             gemm4xN<2>(m0, m, n0, n); | ||||
| #else | ||||
|             gemm<4, 2>(m0, m, n0, n); | ||||
| #endif | ||||
|             break; | ||||
|         case 0x34: | ||||
|         case 0x24: | ||||
|             mc = 2; | ||||
|             nc = 4; | ||||
| #if defined(__AVX2__) && defined(__F16C__) | ||||
|             gemmMx4<2>(m0, m, n0, n); | ||||
| #else | ||||
|             gemm<2, 4>(m0, m, n0, n); | ||||
| #endif | ||||
|             break; | ||||
|         case 0x33: | ||||
| #endif | ||||
| @@ -662,7 +690,11 @@ class tinyBLAS_Q0_AVX { | ||||
|         case 0x41: | ||||
|             mc = 4; | ||||
|             nc = 1; | ||||
| #if defined(__AVX2__) && defined(__F16C__) | ||||
|             gemm4xN<1>(m0, m, n0, n); | ||||
| #else | ||||
|             gemm<4, 1>(m0, m, n0, n); | ||||
| #endif | ||||
|             break; | ||||
|         case 0x22: | ||||
|             mc = 2; | ||||
| @@ -672,7 +704,11 @@ class tinyBLAS_Q0_AVX { | ||||
|         case 0x14: | ||||
|             mc = 1; | ||||
|             nc = 4; | ||||
| #if defined(__AVX2__) && defined(__F16C__) | ||||
|             gemmMx4<1>(m0, m, n0, n); | ||||
| #else | ||||
|             gemm<1, 4>(m0, m, n0, n); | ||||
| #endif | ||||
|             break; | ||||
|         case 0x31: | ||||
|             mc = 3; | ||||
| @@ -708,6 +744,119 @@ class tinyBLAS_Q0_AVX { | ||||
|         mnpack(m0, m, np, n); | ||||
|     } | ||||
|  | ||||
| #if defined(__AVX2__) && defined(__F16C__) | ||||
| // Templated functions for gemm of dimensions 4xN | ||||
|     template <int RN> | ||||
|     NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) { | ||||
|         int64_t ytiles = (m - m0) / 4; | ||||
|         int64_t xtiles = (n - n0) / RN; | ||||
|         int64_t tiles = xtiles * ytiles; | ||||
|         int64_t duty = (tiles + nth - 1) / nth; | ||||
|         int64_t start = duty * ith; | ||||
|         int64_t end = start + duty; | ||||
|         if (end > tiles) | ||||
|             end = tiles; | ||||
|         for (int64_t job = start; job < end; ++job) { | ||||
|             int64_t ii = m0 + job / xtiles * 4; | ||||
|             int64_t jj = n0 + job % xtiles * RN; | ||||
|             __m256 Cv[RN][4] = {}; | ||||
|             for (int64_t l = 0; l < k; ++l) { | ||||
|                 uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d); | ||||
|                 // Convert delta values for four blocks to float values | ||||
|                 __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta)); | ||||
|                 __m256i avec0 = load(A + lda * (ii + 0) + l); | ||||
|                 __m256i avec1 = load(A + lda * (ii + 1) + l); | ||||
|                 __m256i avec2 = load(A + lda * (ii + 2) + l); | ||||
|                 __m256i avec3 = load(A + lda * (ii + 3) + l); | ||||
|                 for (int64_t j = 0; j < RN; ++j) { | ||||
|                         __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d)); | ||||
|                         // Computation of product of delta values for four blocks and replicate it across 256 bit lane | ||||
|                         __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db)); | ||||
|                         dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); | ||||
|                         // Computation of dot product and multiplication with appropriate delta value products | ||||
|                         Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0), | ||||
|                                     updot(_mm256_sign_epi8(avec0, avec0), | ||||
|                                           _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)), | ||||
|                                     Cv[j][0]); | ||||
|                         Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85), | ||||
|                                     updot(_mm256_sign_epi8(avec1, avec1), | ||||
|                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)), | ||||
|                                     Cv[j][1]); | ||||
|                         Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170), | ||||
|                                     updot(_mm256_sign_epi8(avec2, avec2), | ||||
|                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)), | ||||
|                                     Cv[j][2]); | ||||
|                         Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255), | ||||
|                                     updot(_mm256_sign_epi8(avec3, avec3), | ||||
|                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)), | ||||
|                                     Cv[j][3]); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             for (int64_t j = 0; j < RN; ++j) | ||||
|                 for (int64_t i = 0; i < 4; ++i) | ||||
|                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // Templated functions for gemm of dimensions Mx4 | ||||
|     template <int RM> | ||||
|     NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) { | ||||
|         int64_t ytiles = (m - m0) / RM; | ||||
|         int64_t xtiles = (n - n0) / 4; | ||||
|         int64_t tiles = xtiles * ytiles; | ||||
|         int64_t duty = (tiles + nth - 1) / nth; | ||||
|         int64_t start = duty * ith; | ||||
|         int64_t end = start + duty; | ||||
|         if (end > tiles) | ||||
|             end = tiles; | ||||
|         for (int64_t job = start; job < end; ++job) { | ||||
|             int64_t ii = m0 + job / xtiles * RM; | ||||
|             int64_t jj = n0 + job % xtiles * 4; | ||||
|             __m256 Cv[4][RM] = {}; | ||||
|             for (int64_t l = 0; l < k; ++l) { | ||||
|                 uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d); | ||||
|                 // Convert delta values for four blocks to float values | ||||
|                 __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta)); | ||||
|                 __m256i bvec0 = load(B + ldb * (jj + 0) + l); | ||||
|                 __m256i bvec1 = load(B + ldb * (jj + 1) + l); | ||||
|                 __m256i bvec2 = load(B + ldb * (jj + 2) + l); | ||||
|                 __m256i bvec3 = load(B + ldb * (jj + 3) + l); | ||||
|                 for (int64_t i = 0; i < RM; ++i) { | ||||
|                     __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d))); | ||||
|                     // Computation of product of delta values for four blocks and replicate it across 256 bit lane | ||||
|                     __m256 dvec =  _mm256_castps128_ps256(_mm_mul_ps(da, db)); | ||||
|                     dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); | ||||
|                     // Computation of dot product and multiplication with appropriate delta value products | ||||
|                     Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0), | ||||
|                                     updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), | ||||
|                                                             load(A + lda * (ii + i) + l)), | ||||
|                                             _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))), | ||||
|                                     Cv[0][i]); | ||||
|                     Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85), | ||||
|                                     updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), | ||||
|                                                             load(A + lda * (ii + i) + l)), | ||||
|                                             _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))), | ||||
|                                     Cv[1][i]); | ||||
|                     Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170), | ||||
|                                     updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), | ||||
|                                                             load(A + lda * (ii + i) + l)), | ||||
|                                             _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))), | ||||
|                                     Cv[2][i]); | ||||
|                     Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255), | ||||
|                                     updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), | ||||
|                                                             load(A + lda * (ii + i) + l)), | ||||
|                                             _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))), | ||||
|                                     Cv[3][i]); | ||||
|                 } | ||||
|             } | ||||
|             for (int64_t j = 0; j < 4; ++j) | ||||
|                 for (int64_t i = 0; i < RM; ++i) | ||||
|                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
|  | ||||
|     template <int RM, int RN> | ||||
|     NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { | ||||
|         int64_t ytiles = (m - m0) / RM; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Srihari-mcw
					Srihari-mcw