mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	CUDA: fix 1D im2col, add tests (ggml/993)
This commit is contained in:
		 Johannes Gäßler
					Johannes Gäßler
				
			
				
					committed by
					
						 Georgi Gerganov
						Georgi Gerganov
					
				
			
			
				
	
			
			
			 Georgi Gerganov
						Georgi Gerganov
					
				
			
						parent
						
							c19af0acb1
						
					
				
				
					commit
					80273a306d
				
			| @@ -3141,7 +3141,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g | |||||||
|         case GGML_OP_ROPE: |         case GGML_OP_ROPE: | ||||||
|             return ggml_is_contiguous(op->src[0]); |             return ggml_is_contiguous(op->src[0]); | ||||||
|         case GGML_OP_IM2COL: |         case GGML_OP_IM2COL: | ||||||
|             return op->src[0]->type == GGML_TYPE_F16; |  | ||||||
|         case GGML_OP_POOL_2D: |         case GGML_OP_POOL_2D: | ||||||
|         case GGML_OP_SUM: |         case GGML_OP_SUM: | ||||||
|         case GGML_OP_SUM_ROWS: |         case GGML_OP_SUM_ROWS: | ||||||
|   | |||||||
| @@ -91,9 +91,9 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | |||||||
|     const int64_t OH = is_2D ? dst->ne[2] : 1; |     const int64_t OH = is_2D ? dst->ne[2] : 1; | ||||||
|     const int64_t OW =         dst->ne[1]; |     const int64_t OW =         dst->ne[1]; | ||||||
|  |  | ||||||
|     const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 |     const size_t  delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 | ||||||
|     const int64_t batch = src1->ne[3]; |     const int64_t batch        = src1->ne[is_2D ? 3 : 2]; | ||||||
|     const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 |     const size_t  batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 | ||||||
|  |  | ||||||
|     if(dst->type == GGML_TYPE_F16) { |     if(dst->type == GGML_TYPE_F16) { | ||||||
|         im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); |         im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); | ||||||
|   | |||||||
| @@ -3308,15 +3308,41 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32)); |     // im2col 1D | ||||||
|     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32)); |  | ||||||
|     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16)); |  | ||||||
|     // test cases for 1D im2col |  | ||||||
|     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); |     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); | ||||||
|     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); |     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); | ||||||
|     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); |     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false)); | ||||||
|  |     for (int s0 : {1, 3}) { | ||||||
|  |         for (int p0 : {0, 3}) { | ||||||
|  |             for (int d0 : {1, 3}) { | ||||||
|  |                 test_cases.emplace_back(new test_im2col( | ||||||
|  |                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 2, 2, 1}, {3, 2, 2, 1}, | ||||||
|  |                     s0, 0, p0, 0, d0, 0, false)); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     // test cases for 2D im2col |     // im2col 2D | ||||||
|  |     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32)); | ||||||
|  |     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32)); | ||||||
|  |     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16)); | ||||||
|  |     for (int s0 : {1, 3}) { | ||||||
|  |         for (int s1 : {1, 3}) { | ||||||
|  |             for (int p0 : {0, 3}) { | ||||||
|  |                 for (int p1 : {0, 3}) { | ||||||
|  |                     for (int d0 : {1, 3}) { | ||||||
|  |                         for (int d1 : {1, 3}) { | ||||||
|  |                             test_cases.emplace_back(new test_im2col( | ||||||
|  |                                 GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 2, 2}, {3, 3, 2, 2}, | ||||||
|  |                                 s0, s1, p0, p1, d0, d1, true)); | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // extra tests for im2col 2D | ||||||
|     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true)); |     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 32}, {3, 3, 1, 32}, 1, 1, 1, 1, 1, 1, true)); | ||||||
|     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true)); |     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 32}, {3, 3, 2, 32}, 1, 1, 1, 1, 1, 1, true)); | ||||||
|     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true)); |     test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 1024}, {3, 3, 1, 1024}, 1, 1, 1, 1, 1, 1, true)); | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user