mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-08 10:07:01 +00:00
* ggml : remove adding extra dim timestep embedding This commit updates the ggml_timestep_embedding function to no longer add an extra dimension when the specified dimension is odd. The motivation for this change is that this introduces an unnecessary dimension when the dimension is odd, which caused an issue in the kernels which were not expecting this extra dimension and it resulted in uninitialized memory for the second to last dimension. * ggml-cuda : fix padding in timestep embedding kernel This commit removes the zeroing out of the last dimension now that we are not adding the extra padding dimension. * ggml-metal : fix padding in timestep embedding kernel This commit fixes the zero padding for odd dimensions in the timestep embedding kernel * ggml-opencl : fix padding in timestep embedding kernel This commit fixes the zero padding for odd dimensions in the timestep embedding kernel. * ggml-sycl : fix padding in timestep embedding kernel This commit fixes the zero padding for odd dimensions in the timestep embedding kernel. * ggml-vulkan : fix padding in timestep embedding kernel This commit fixes the zero padding for odd dimensions in the timestep embedding kernel. * ggml-cpu : fix padding in timestep embedding function This commit removes the zeroing out of the last dimension now that we are not adding the extra padding dimension.
74 lines
2.6 KiB
C++
74 lines
2.6 KiB
C++
//
|
|
// MIT license
|
|
// Copyright (C) 2024 Intel Corporation
|
|
// SPDX-License-Identifier: MIT
|
|
//
|
|
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
|
|
#include "tsembd.hpp"
|
|
|
|
static void timestep_embedding_f32(
|
|
const float * timesteps, float * dst, const int nb1,
|
|
const int dim, const int max_period, const sycl::nd_item<3> &item_ct1) {
|
|
// item_ct1.get_group(1)(blockIDx.y): idx of timesteps->ne[0]
|
|
// item_ct1.get_group(2) (blockIDx.x): idx of ((dim + 1) / 2) / BLOCK_SIZE
|
|
int i = item_ct1.get_group(1);
|
|
int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
|
float * embed_data = (float *)((char *)dst + i*nb1);
|
|
|
|
int half = dim / 2;
|
|
|
|
if (dim % 2 != 0 && j == half) {
|
|
embed_data[2 * half] = 0.f;
|
|
}
|
|
|
|
if (j >= half) {
|
|
return;
|
|
}
|
|
|
|
float timestep = timesteps[i];
|
|
float freq = (float)sycl::native::exp(-(sycl::log((float)max_period)) * j / half);
|
|
float arg = timestep * freq;
|
|
embed_data[j] = sycl::cos(arg);
|
|
embed_data[j + half] = sycl::sin(arg);
|
|
}
|
|
|
|
static void timestep_embedding_f32_sycl(
|
|
const float * x, float * dst, const int ne00, const int nb1,
|
|
const int dim, const int max_period, const queue_ptr& stream) {
|
|
// As the kernel returns when thread.idx is larger than dim/2, the half_ceil does not need to pad
|
|
int half_ceil = dim / 2;
|
|
int num_blocks = (half_ceil + SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE;
|
|
sycl::range<3> block_dims(1, 1, SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE);
|
|
sycl::range<3> gridDim(1, ne00, num_blocks);
|
|
stream->parallel_for(
|
|
sycl::nd_range<3>(
|
|
gridDim * block_dims, block_dims),
|
|
[=](sycl::nd_item<3> item_ct1) {
|
|
timestep_embedding_f32(
|
|
x, dst, nb1, dim, max_period, item_ct1
|
|
);
|
|
});
|
|
}
|
|
|
|
void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
|
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
|
|
const ggml_tensor * src0 = dst->src[0];
|
|
const float * src0_d = (const float *)src0->data;
|
|
float * dst_d = (float *)dst->data;
|
|
dpct::queue_ptr stream = ctx.stream();
|
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
|
|
const int dim = dst->op_params[0];
|
|
const int max_period = dst->op_params[1];
|
|
|
|
timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
|
|
}
|