llama : add gpt-oss (#15091)

* oai moe

* compat with new checkpoint

* add attn sink impl

* add rope scaling yarn

* logits match with latest transformers code

* wip chat template

* rm trailing space

* use ggml_scale_bias

* rm redundant is_swa_all

* convert interleaved gate_up

* graph : fix activation function to match reference (#7)

* vocab : handle o200k_harmony special tokens

* ggml : add attention sinks support (#1)

* llama : add attn sinks

* ggml : add attn sinks

* cuda : add attn sinks

* vulkan : add support for sinks in softmax

remove unnecessary return

* ggml : add fused swiglu_oai op (#11)

* ggml : add fused swiglu_oai op

* Update ggml/src/ggml-cpu/ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* update CUDA impl

* cont : metal impl

* add vulkan impl

* test-backend-ops : more test cases, clean up

* llama : remove unfused impl

* remove extra lines

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: slaren <slarengh@gmail.com>

* repack mxfp4 upon conversion

* clean up a bit

* enable thinking

* add quick hack to render only some special tokens

* fix bf16 conversion

* remove vocab hack

* webui ok

* support chat parsing for gpt-oss

* fix webui

* direct mapping mxfp4, FINALLY

* force using mxfp4

* properly use lazy tensor

* ggml : add mxfp4

ggml : use e8m0 conversion instead of powf

Co-authored-by: Diego Devesa <slarengh@gmail.com>

change kvalues_mxfp4 table to match e2m1 (#6)

metal : remove quantization for now (not used)

cuda : fix disabled CUDA graphs due to ffn moe bias

vulkan : add support for mxfp4

cont : add cm2 dequant

* ggml : add ggml_add_id (#13)

* ggml : add ggml_add_id

* add cuda impl

* llama : add weight support check for add_id

* perf opt

* add vulkan impl

* rename cuda files

* add metal impl

* allow in-place ggml_add_id

* llama : keep biases on CPU with --cpu-moe

* llama : fix compile error

ggml-ci

* cuda : add fallback for __nv_cvt_e8m0_to_bf16raw

ggml-ci

* cleanup

ggml-ci

* sycl : fix supports_op for MXFP4

ggml-ci

* fix Unknown reasoning format

* ggml-cpu : fix AVX build

ggml-ci

* fix hip build

ggml-ci

* cuda : add mxfp4 dequantization support for cuBLAS

ggml-ci

* ggml-cpu : fix mxfp4 fallback definitions for some architectures

ggml-ci

* cuda : fix version required for __nv_cvt_e8m0_to_bf16raw

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Georgi Gerganov
2025-08-05 22:10:36 +03:00
committed by GitHub
parent f324a3b715
commit fd1234cb46
83 changed files with 2942 additions and 227 deletions

View File

@@ -1,8 +1,20 @@
#pragma once
#include "common.cuh"
#include <cstdint>
static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
const uint8_t * x8 = (const uint8_t *) x;
int x32 = x8[4*i32 + 0] << 0;
x32 |= x8[4*i32 + 1] << 8;
x32 |= x8[4*i32 + 2] << 16;
x32 |= x8[4*i32 + 3] << 24;
return x32;
}
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
@@ -16,6 +28,20 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
const int8_t * q1_8 = (const int8_t *) &q1_32;
const char4 val1_8 = make_char4(
table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
}
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
@@ -211,6 +237,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
return d8_1*sumf;
}
#define VDR_MXFP4_Q8_1_MMVQ 2
#define VDR_MXFP4_Q8_1_MMQ 4
static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
const int * q8 = (const int *) bq8_1->qs + iqs;
int sumi = 0;
#pragma unroll
for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
}
const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
return d * sumi;
}
#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 4
@@ -1068,20 +1118,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
}
static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
const int8_t * q0_8 = (const int8_t *) &q0_32;
const char4 val0_8 = make_char4(
kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
const int8_t * q1_8 = (const int8_t *) &q1_32;
const char4 val1_8 = make_char4(
kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
}
#define VDR_IQ4_NL_Q8_1_MMVQ 2
#define VDR_IQ4_NL_Q8_1_MMQ 4
@@ -1096,7 +1132,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
#pragma unroll
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
const int2 v = get_int_from_table_16(aux_q4);
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
@@ -1118,7 +1154,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
const int2 v = get_int_from_table_16(aux_q4);
const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);