mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
cuda : fix supports_op condition for get_rows when number of blocks is too large (#15868)
* cuda : fix supports_op condition for get_rows when src1->ne2 > 1 ggml-ci * ggml : add comment about ggml_get_rows ggml-ci * cuda : add FIXME [no ci] * cuda : update support condition ggml-ci
This commit is contained in:
@@ -1529,7 +1529,11 @@ extern "C" {
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// supports 3D: a->ne[2] == b->ne[1]
|
||||
// supports 4D a:
|
||||
// a [n_embd, ne1, ne2, ne3]
|
||||
// b I32 [n_rows, ne2, ne3, 1]
|
||||
//
|
||||
// return [n_embd, n_rows, ne2, ne3]
|
||||
GGML_API struct ggml_tensor * ggml_get_rows(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a, // data
|
||||
|
||||
@@ -3392,6 +3392,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
// FIXME: https://github.com/ggml-org/llama.cpp/pull/15868
|
||||
if (op->src[1]->ne[1]*op->src[1]->ne[2] > 65535) {
|
||||
return false;
|
||||
}
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
|
||||
@@ -3623,6 +3623,7 @@ struct ggml_tensor * ggml_get_rows(
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b) {
|
||||
GGML_ASSERT(a->ne[2] == b->ne[1]);
|
||||
GGML_ASSERT(a->ne[3] == b->ne[2]);
|
||||
GGML_ASSERT(b->ne[3] == 1);
|
||||
GGML_ASSERT(b->type == GGML_TYPE_I32);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user