ggml : fix padding in timestep embedding kernels (#15932)

* ggml : remove adding extra dim timestep embedding

This commit updates the ggml_timestep_embedding function to no longer
add an extra dimension when the specified dimension is odd.

The motivation for this change is that this introduces an unnecessary
dimension when the dimension is odd, which caused an issue in the
kernels which were not expecting this extra dimension and it resulted in
uninitialized memory for the second to last dimension.

* ggml-cuda : fix padding in timestep embedding kernel

This commit removes the zeroing out of the last dimension now that we
are not adding the extra padding dimension.

* ggml-metal : fix padding in timestep embedding kernel

This commit fixes the zero padding for odd dimensions in
the timestep embedding kernel

* ggml-opencl : fix padding in timestep embedding kernel

This commit fixes the zero padding for odd dimensions in
the timestep embedding kernel.

* ggml-sycl : fix padding in timestep embedding kernel

This commit fixes the zero padding for odd dimensions in
the timestep embedding kernel.

* ggml-vulkan : fix padding in timestep embedding kernel

This commit fixes the zero padding for odd dimensions in
the timestep embedding kernel.

* ggml-cpu : fix padding in timestep embedding function

This commit removes the zeroing out of the last dimension now that we
are not adding the extra padding dimension.
This commit is contained in:
Daniel Bevenius
2025-09-16 15:25:57 +02:00
committed by GitHub
parent 76888d202e
commit 3913f8730e
7 changed files with 15 additions and 18 deletions

View File

@@ -8599,7 +8599,6 @@ static void ggml_compute_forward_timestep_embedding_f32(
}
if (dim % 2 != 0 && ith == 0) {
embed_data[2 * half] = 0.f;
embed_data[dim] = 0.f;
}
}
}

View File

@@ -7,11 +7,11 @@ static __global__ void timestep_embedding_f32(const float * timesteps, float * d
int j = threadIdx.x + blockIdx.x * blockDim.x;
float * embed_data = (float *)((char *)dst + i*nb1);
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
embed_data[dim] = 0.f;
int half = dim / 2;
if (dim % 2 != 0 && j == half) {
embed_data[2 * half] = 0.f;
}
int half = dim / 2;
if (j >= half) {
return;
}

View File

@@ -4167,7 +4167,7 @@ kernel void kernel_timestep_embedding_f32(
}
if (args.dim % 2 != 0 && tpitg.x == 0) {
embed_data[args.dim] = 0.f;
embed_data[2 * half_] = 0.f;
}
}

View File

@@ -26,8 +26,8 @@ kernel void kernel_timestep_embedding(
local_half_dim = logical_dim / 2;
local_embed_data_ptr = (global float *)((global char *)local_dst_output_base_ptr + local_i * dst_nb1_bytes);
if (logical_dim % 2 != 0 && local_j == ((logical_dim + 1) / 2)) {
local_embed_data_ptr[logical_dim] = 0.0f;
if (logical_dim % 2 != 0 && local_j == local_half_dim) {
local_embed_data_ptr[2 * local_half_dim] = 0.0f;
}
if (local_j >= local_half_dim) {

View File

@@ -21,11 +21,12 @@ static void timestep_embedding_f32(
int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
float * embed_data = (float *)((char *)dst + i*nb1);
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
embed_data[dim] = 0.f;
int half = dim / 2;
if (dim % 2 != 0 && j == half) {
embed_data[2 * half] = 0.f;
}
int half = dim / 2;
if (j >= half) {
return;
}

View File

@@ -24,11 +24,12 @@ void main() {
const uint j = gl_GlobalInvocationID.x;
const uint d_offset = i * p.nb1;
if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) {
data_d[d_offset + p.dim] = 0.f;
const uint half_dim = p.dim / 2;
if (p.dim % 2 != 0 && j == half_dim) {
data_d[d_offset + 2 * half_dim] = 0.f;
}
const uint half_dim = p.dim / 2;
if (j >= half_dim) {
return;
}

View File

@@ -4923,12 +4923,8 @@ struct ggml_tensor * ggml_timestep_embedding(
struct ggml_tensor * timesteps,
int dim,
int max_period) {
int actual_dim = dim;
if (dim % 2 != 0) {
actual_dim = dim + 1;
}
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps->ne[0]);
ggml_set_op_params_i32(result, 0, dim);
ggml_set_op_params_i32(result, 1, max_period);