mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	sgemm : AVX Q4_0 and Q8_0 (#6891)
* basic avx implementation * style * combine denibble with load * reduce 256 to 128 (and back!) conversions * sse load * Update sgemm.cpp * oops oops
This commit is contained in:
		
							
								
								
									
										77
									
								
								sgemm.cpp
									
									
									
									
									
								
							
							
						
						
									
										77
									
								
								sgemm.cpp
									
									
									
									
									
								
							| @@ -1,6 +1,3 @@ | ||||
| // -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*- | ||||
| // vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi | ||||
| // | ||||
| // Copyright 2024 Mozilla Foundation | ||||
| // | ||||
| // Permission is hereby granted, free of charge, to any person obtaining | ||||
| @@ -585,15 +582,15 @@ class tinyBLAS_Q0_ARM { | ||||
| }; | ||||
| #endif // __ARM_FEATURE_DOTPROD | ||||
|  | ||||
| #if defined(__AVX2__) || defined(__AVX512F__) | ||||
| #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) | ||||
| template <typename TA, typename TB, typename TC> | ||||
| class tinyBLAS_Q0_AVX2 { | ||||
| class tinyBLAS_Q0_AVX { | ||||
|   public: | ||||
|     tinyBLAS_Q0_AVX2(int64_t k, | ||||
|                      const TA *A, int64_t lda, | ||||
|                      const TB *B, int64_t ldb, | ||||
|                      TC *C, int64_t ldc, | ||||
|                      int ith, int nth) | ||||
|     tinyBLAS_Q0_AVX(int64_t k, | ||||
|                     const TA *A, int64_t lda, | ||||
|                     const TB *B, int64_t ldb, | ||||
|                     TC *C, int64_t ldc, | ||||
|                     int ith, int nth) | ||||
|         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { | ||||
|     } | ||||
|  | ||||
| @@ -728,14 +725,34 @@ class tinyBLAS_Q0_AVX2 { | ||||
|             __m256 Cv[RN][RM] = {}; | ||||
|             for (int64_t l = 0; l < k; ++l) | ||||
|                 for (int64_t j = 0; j < RN; ++j) | ||||
|                     for (int64_t i = 0; i < RM; ++i) | ||||
|                     for (int64_t i = 0; i < RM; ++i) { | ||||
| #if defined(__AVX2__) | ||||
|                         __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), | ||||
|                                                               load(A + lda * (ii + i) + l)), | ||||
|                                              _mm256_sign_epi8(load(B + ldb * (jj + j) + l), | ||||
|                                                               load(A + lda * (ii + i) + l))); | ||||
| #else | ||||
|                         __m128i ali0 = load0(A + lda * (ii + i) + l); | ||||
|                         __m128i ali1 = load1(A + lda * (ii + i) + l); | ||||
|                         __m128i blj0 = load0(B + ldb * (jj + j) + l); | ||||
|                         __m128i blj1 = load1(B + ldb * (jj + j) + l); | ||||
|  | ||||
|                         __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); | ||||
|                         __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); | ||||
|                         __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); | ||||
|                         __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); | ||||
|  | ||||
|                         // updot | ||||
|                         const __m128i oneFill = _mm_set1_epi16(1); | ||||
|                         __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); | ||||
|                         __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); | ||||
|                         __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); | ||||
| #endif | ||||
|                         Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * | ||||
|                                                        unhalf(B[ldb * (jj + j) + l].d)), | ||||
|                                         updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), | ||||
|                                                                load(A + lda * (ii + i) + l)), | ||||
|                                               _mm256_sign_epi8(load(B + ldb * (jj + j) + l), | ||||
|                                                                load(A + lda * (ii + i) + l))), | ||||
|                                         Cv[j][i]); | ||||
|                                                        udTmp, | ||||
|                                                        Cv[j][i]); | ||||
|                     } | ||||
|             for (int64_t j = 0; j < RN; ++j) | ||||
|                 for (int64_t i = 0; i < RM; ++i) | ||||
|                     C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); | ||||
| @@ -746,10 +763,28 @@ class tinyBLAS_Q0_AVX2 { | ||||
|         return _mm256_loadu_si256((const __m256i *)b->qs); | ||||
|     } | ||||
|  | ||||
|     inline __m128i load0(const block_q8_0 *b) { | ||||
|         return _mm_loadu_si128((const __m128i *)b->qs); | ||||
|     } | ||||
|  | ||||
|     inline __m128i load1(const block_q8_0 *b) { | ||||
|         return _mm_loadu_si128(((const __m128i *)b->qs) + 1); | ||||
|     } | ||||
|  | ||||
|     inline __m256i load(const block_q4_0 *b) { | ||||
|         return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); | ||||
|     } | ||||
|  | ||||
|     inline __m128i load0(const block_q4_0 *b) { | ||||
|         const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); | ||||
|         return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8)); | ||||
|     } | ||||
|  | ||||
|     inline __m128i load1(const block_q4_0 *b) { | ||||
|         const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); | ||||
|         return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); | ||||
|     } | ||||
|  | ||||
|     inline __m256 updot(__m256i u, __m256i s) { | ||||
|         __m256i res; | ||||
| #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) | ||||
| @@ -777,7 +812,7 @@ class tinyBLAS_Q0_AVX2 { | ||||
|     const int ith; | ||||
|     const int nth; | ||||
| }; | ||||
| #endif // __AVX2__ | ||||
| #endif // __AVX__ | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| @@ -928,8 +963,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda | ||||
|     case GGML_TYPE_Q8_0: { | ||||
|         if (Btype != GGML_TYPE_Q8_0) | ||||
|            return false; | ||||
| #if defined(__AVX2__) || defined(__AVX512F__) | ||||
|         tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{ | ||||
| #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) | ||||
|         tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{ | ||||
|             k, (const block_q8_0 *)A, lda, | ||||
|             (const block_q8_0 *)B, ldb, | ||||
|             (float *)C, ldc, | ||||
| @@ -952,8 +987,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda | ||||
|     case GGML_TYPE_Q4_0: { | ||||
|         if (Btype != GGML_TYPE_Q8_0) | ||||
|             return false; | ||||
| #if defined(__AVX2__) || defined(__AVX512F__) | ||||
|         tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{ | ||||
| #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) | ||||
|         tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{ | ||||
|             k, (const block_q4_0 *)A, lda, | ||||
|             (const block_q8_0 *)B, ldb, | ||||
|             (float *)C, ldc, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Eve
					Eve