mirror of https://github.com/ggml-org/llama.cpp.git
	CPU/CUDA: fix (GQA) mul mat back, add CUDA support (#11380)
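This commit fixes the MUL_MAT backward pass for the broadcast (GQA) case, where src0 has fewer ne[2]/ne[3] slices than src1, and adds the CUDA support needed for it: a 4-dimensional, view-capable REPEAT_BACK kernel and correct leading dimensions for OUT_PROD on noncontiguous tensors. As a rough sketch of the shapes involved (sizes are illustrative assumptions, not from the commit):

    // Assuming an initialized ggml_context * ctx; 4 KV heads, 16 query heads.
    struct ggml_tensor * src0 = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 32,  4, 1); // [n, m, q1, r1]
    struct ggml_tensor * src1 = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64,  8, 16, 1); // [n, p, qq, rr]
    struct ggml_tensor * dst  = ggml_mul_mat(ctx, src0, src1);                         // [m, p, qq, rr]
    // In the backward pass, ggml_out_prod(ctx, src1, grad) has shape [n, m, qq, rr]:
    // its qq = 16 slices must be summed back into src0's q1 = 4 slices.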
@@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:
-            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
+            return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
+                src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         default:
             return true;
     }
@@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
 
 template <typename T>
 static __global__ void k_repeat_back(
-    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
 
-    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
-    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
-    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
+    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
+    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
+    const int64_t tid2  = tid23 % ne2;
+    const int64_t tid3  = tid23 / ne2;
 
     if (tid0 >= ne0) {
         return;
     }
 
     T sum = 0;
+    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
         for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
             for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
                 for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
-                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
                 }
             }
         }
-    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+    }
+    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }
 
 template<float (*bin_op)(const float, const float)>
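The reworked k_repeat_back kernel handles all four dimensions and takes explicit element strides s00..s03, so it can also read from noncontiguous views; dims 2 and 3 share the z grid dimension via the tid23 split. A hedged sketch of the launch geometry, with made-up sizes:

    // Illustrative only; the ne* values are assumptions, WARP_SIZE is 32.
    const int64_t ne0 = 1000, ne1 = 6, ne2 = 16, ne3 = 2; // dst dims
    const dim3 block_dims(WARP_SIZE, 1, 1);
    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3); // grid.z = 32
    // CUDA caps gridDim.y/z at 65535; keeping ne2*ne3 <= (1 << 15), as the host
    // code below asserts, stays safely under that limit.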
@@ -274,12 +279,14 @@ struct bin_bcast_cuda {
 
 template <typename T>
 static void repeat_back_cuda(
-    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
-    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
+        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
 }
 
 template<class op>
@@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(src0->type == dst->type);
-    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_can_repeat(dst, src0));
 
     cudaStream_t stream = ctx.stream();
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    GGML_ASSERT(dst->ne[3] == 1);
+    GGML_ASSERT(ne2*ne3 <= (1 << 15));
+
+    const size_t ts = ggml_type_size(src0->type);
+    const size_t s00 = nb00 / ts;
+    const size_t s01 = nb01 / ts;
+    const size_t s02 = nb02 / ts;
+    const size_t s03 = nb03 / ts;
 
     switch (dst->type) {
         case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            float       * dst_d  = (float       *) dst->data;
-            repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
         } break;
         default: {
             GGML_ASSERT(false);
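GGML_TENSOR_UNARY_OP_LOCALS expands the ne0x/nb0x locals for src0 and ne/nb for dst; the nb strides are in bytes while the kernel indexes in elements, hence the division by the type size. A quick numeric check with an arbitrary contiguous F32 tensor:

    // Shape [8, 6, 4, 2], GGML_TYPE_F32 (4 bytes per element):
    //   nb00 =   4 -> s00 =   1
    //   nb01 =  32 -> s01 =   8
    //   nb02 = 192 -> s02 =  48
    //   nb03 = 768 -> s03 = 192
    // A noncontiguous view breaks this factorization, which is exactly what the
    // stride-based kernel now supports (and why the
    // GGML_ASSERT(ggml_is_contiguous(src0)) could be dropped).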
@@ -3002,7 +3002,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             } break;
         case GGML_OP_REPEAT_BACK:
-                return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
+                return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     CUBLAS_CHECK(cublasSetStream(handle, stream));
 
+    const int64_t lda = nb01 / sizeof(float);
+    const int64_t ldc = nb1  / sizeof(float);
+
     const bool src1_T = ggml_is_transposed(src1);
     const cublasOperation_t src1_cublas_op =  src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
     const int64_t           ldb            = (src1_T ?        nb10 :        nb11) /  sizeof(float);
@@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             CUBLAS_CHECK(
                 cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
                         ne0, ne1, ne01,
-                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
+                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
                                 src1_d +  i3      *s13 +  i2      *s12, ldb,
-                        &beta,  dst_d  +  i3      *s3  +  i2      *s2,  ne0));
+                        &beta,  dst_d  +  i3      *s3  +  i2      *s2,  ldc));
         }
     }
 }
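cuBLAS is column-major, so lda/ldc are the element strides between consecutive columns of the A and C matrices. The old code hard-coded them to ne00 and ne0, which is only valid for contiguous tensors; deriving them from the byte strides also covers views. Restating the two added lines with that reading:

    const int64_t lda = nb01 / sizeof(float); // equals ne00 only if src0 is contiguous
    const int64_t ldc = nb1  / sizeof(float); // equals ne0  only if dst  is contiguous
    // With the hard-coded values, a noncontiguous src0 or dst would be
    // indexed with the wrong pitch.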
@@ -5339,7 +5339,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MUL: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
             }
             if (src1_needs_grads) {
                 struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
@@ -5431,21 +5431,25 @@ static void ggml_compute_backward(
             // src1.shape   [n,p,qq,rr]
 
             if (src0_needs_grads) {
-                struct ggml_tensor * s1_tg =
+                GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+                GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+                struct ggml_tensor * tmp =
                     ggml_out_prod(ctx, // [n,m,qq,rr]
                         src1,          // [n,p,qq,rr]
                         grad);         // [m,p,qq,rr]
-                const int64_t qq = s1_tg->ne[2];
-                const int64_t rr = s1_tg->ne[3];
-                const int64_t q1 = src0->ne[2];
-                const int64_t r1 = src0->ne[3];
-                const bool ne2_broadcasted = qq > q1;
-                const bool ne3_broadcasted = rr > r1;
-                if (ne2_broadcasted || ne3_broadcasted) {
-                    // sum broadcast repetitions of s1_tg into shape of src0
-                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+                if (!ggml_are_same_shape(tmp, src0)) {
+                    GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+                    GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+                    GGML_ASSERT(tmp->ne[3] == 1);
+
+                    const int64_t nr2 = tmp->ne[2] / src0->ne[2];
+                    const size_t nb2 = tmp->nb[2] * nr2;
+                    const size_t nb3 = tmp->nb[2];
+
+                    tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
+                    tmp = ggml_repeat_back(ctx, tmp, src0);
                 }
-                ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+                ggml_add_or_set(ctx, cgraph, isrc0, tmp);
             }
             if (src1_needs_grads) {
                 ggml_add_or_set(ctx, cgraph, isrc1,
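This is the core of the GQA fix. ggml_mul_mat broadcasts src0 in contiguous blocks (src1 slice i12 uses src0 slice i12/nr2), while ggml_repeat_back sums periodic tiles (output slice i2 accumulates input slices i2, i2 + ne2, i2 + 2*ne2, ...). The ggml_view_4d with swapped strides (nb3 = tmp->nb[2], nb2 = tmp->nb[2]*nr2) reinterprets the blocked [n,m,qq,1] gradient as [n,m,q1,nr2], so the reduction sums exactly the slices belonging to each src0 slice. An index sketch with assumed sizes:

    // q1 = 2 src0 slices, broadcast nr2 = 3 times, so qq = 6; let nb = tmp->nb[2].
    // Blocked forward grouping: grad slices {0,1,2} belong to src0 slice 0,
    // slices {3,4,5} to src0 slice 1.
    // View [ne0, ne1, 2, 3] with nb2 = 3*nb and nb3 = nb: element (i2, i3)
    // addresses original slice i2*3 + i3, so summing over dim 3 (what
    // ggml_repeat_back does here, since src0->ne[3] == 1) reduces
    // {0,1,2} -> slice 0 and {3,4,5} -> slice 1.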
@@ -5514,7 +5518,9 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
                 GGML_ASSERT(ggml_is_contiguous(grad));
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
+                ggml_add_or_set(ctx, cgraph, isrc0,
+                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
             }
         } break;
         case GGML_OP_RESHAPE: {
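Judging by the surrounding cases, this hunk is the CONT backward path: ggml_cont can change the logical shape while preserving the element count (e.g. the ggml_cont_2d/3d/4d variants), so the incoming gradient may need a reshape before accumulation. A minimal illustration, assuming ctx and a 3d tensor x exist:

    // Flatten a 3d tensor; ggml_nelements stays the same, the dims do not:
    struct ggml_tensor * x2d = ggml_cont_2d(ctx, x, x->ne[0]*x->ne[1], x->ne[2]);
    // x2d's gradient has x2d's shape; accumulating it into x's gradient needs
    // the ggml_reshape(ctx, grad, src0) added above.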
@@ -1302,6 +1302,59 @@ struct test_repeat : public test_case {
     }
 };
 
+// GGML_OP_REPEAT_BACK
+struct test_repeat_back : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int, 4> nr;
+    const bool v; // whether src is a noncontiguous view
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, nr, v);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 2;
+    }
+
+    test_repeat_back(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {8, 6, 4, 2},
+            std::array<int, 4> nr = {2, 2, 2, 2},
+            bool v = false)
+        : type(type), ne(ne), nr(nr), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(src, "src");
+
+        if (v) {
+            GGML_ASSERT(ne[0] % 2 == 0);
+            GGML_ASSERT(ne[1] % 2 == 0);
+            GGML_ASSERT(ne[2] % 2 == 0);
+            GGML_ASSERT(ne[3] % 2 == 0);
+            GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
+            GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
+            GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
+            GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
+
+            const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
+            const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
+            const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
+            const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
+
+            src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
+        }
+
+        ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(target, "target");
+
+        ggml_tensor * out = ggml_repeat_back(ctx, src, target);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_DUP
 struct test_dup : public test_case {
     const ggml_type type;
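test_repeat_back allocates src already expanded by the repeat factors nr and, in the view variant, halves every repeated dimension while keeping the full tensor's strides, producing a noncontiguous input. One concrete instantiation from the eval list added below:

    test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, true));
    // src is allocated as [16, 6, 4, 2]; with v == true it becomes a
    // noncontiguous [8, 6, 4, 2] view (offset 0, full-tensor strides), which
    // ggml_repeat_back then reduces onto the [8, 6, 4, 2] target.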
@@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case {
         return 5e-4;
     }
 
+    int64_t grad_nmax() override {
+        return 20000;
+    }
+
     uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
         return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
@@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case {
 
             a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
             b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
-            ggml_set_param(ctx, a);
-            ggml_set_param(ctx, b);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(ctx, a);
+                }
+                ggml_set_param(ctx, b);
+            }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
 
@@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case {
         } else {
             a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0],       bs[1]);
             b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-            ggml_set_param(ctx, a);
-            ggml_set_param(ctx, b);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(ctx, a);
+                }
+                ggml_set_param(ctx, b);
+            }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
         }
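The harness changes gate what gets gradient-checked: grad_nmax() overrides the element cap for numeric gradient checks (set to 20000 here, since the new batched cases are larger), ggml_set_param is skipped entirely for quantized type_a, and a only becomes a parameter when bs[1] == 1 && nr[1] == 1. The latter presumably matches the new MUL_MAT backward, which still asserts tmp->ne[3] == 1, i.e. it does not yet reduce a broadcast in dim 3 back into src0.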
@@ -3798,6 +3863,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
     }
 
+    for (bool view : {false, true}) {
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+    }
+
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
     test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
@@ -3920,20 +3995,24 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             // test cases without permutation
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10,  1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {2, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 2}));
 
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
 
             // test cases with permutation
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
Author: Johannes Gäßler