Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-15 11:17:31 +00:00)
ggml : implement set_rows with i32 index (#16159)
* implement set_rows with i32 index
* template fix
* test quantized path warnings--
* Apply suggestions from code review
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* forgotten name change
* deduplicate cuda/sycl and test-fix
* indent++
* vulkan: support set_rows with i32 index type (#16162)
* disable i32 index for webgpu for now

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
This commit is contained in:
@@ -4739,6 +4739,7 @@ void ggml_compute_forward_get_rows(
|
||||
//}
|
||||
}
|
||||
|
||||
template<typename idx_t>
|
||||
static void ggml_compute_forward_set_rows_f32(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
@@ -4777,7 +4778,7 @@ static void ggml_compute_forward_set_rows_f32(
|
||||
const int64_t i11 = i02%ne11;
|
||||
const int64_t i10 = i;
|
||||
|
||||
const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
||||
const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
||||
|
||||
GGML_ASSERT(i1 >= 0 && i1 < ne1);
|
||||
|
||||
@@ -4794,11 +4795,18 @@ void ggml_compute_forward_set_rows(
|
||||
ggml_tensor * dst) {
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_set_rows_f32(params, dst);
|
||||
if (src1->type == GGML_TYPE_I64) {
|
||||
ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
|
||||
} else if (src1->type == GGML_TYPE_I32) {
|
||||
ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
|
||||
} else {
|
||||
GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user