diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index d515ec0a32..25d43122e9 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -3022,6 +3022,13 @@ static bool ggml_metal_encode_node( const int64_t shmem_size = d_state / 32; GGML_ASSERT(shmem_size * 32 == d_state); + // The final simd_sum won't work if the number of simd groups is + // larger than the size of a single simd group. If this case is + // hit at some point, the logic in the second simd_sum could be + // expanded to handle this with one more sequential simd_sum to + // collapse simd group sums another time. + GGML_ASSERT(shmem_size <= 32); + // One thread pre element in d_state GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);