mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
CUDA: General GEMV fusion (#16715)
This commit is contained in:
@@ -1005,3 +1005,16 @@ struct ggml_backend_cuda_context {
|
|||||||
return pool(device);
|
return pool(device);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ggml_cuda_mm_fusion_args_host {
|
||||||
|
const ggml_tensor * x_bias = nullptr;
|
||||||
|
const ggml_tensor * gate = nullptr;
|
||||||
|
const ggml_tensor * gate_bias = nullptr;
|
||||||
|
ggml_glu_op glu_op;
|
||||||
|
};
|
||||||
|
struct ggml_cuda_mm_fusion_args_device {
|
||||||
|
const void * x_bias = nullptr;
|
||||||
|
const void * gate = nullptr;
|
||||||
|
const void * gate_bias = nullptr;
|
||||||
|
ggml_glu_op glu_op;
|
||||||
|
};
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#pragma once
|
||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
|
|
||||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||||
|
|||||||
@@ -2007,6 +2007,147 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
|
||||||
|
const ggml_tensor * ffn_gate,
|
||||||
|
const ggml_tensor * glu,
|
||||||
|
const ggml_tensor * ffn_up_bias = nullptr,
|
||||||
|
const ggml_tensor * ffn_gate_bias = nullptr) {
|
||||||
|
const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;
|
||||||
|
|
||||||
|
if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
|
||||||
|
const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;
|
||||||
|
|
||||||
|
GGML_ASSERT(ffn_up && ffn_gate && glu);
|
||||||
|
|
||||||
|
if (!is_mul_mat && !is_mul_mat_id) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||||
|
|
||||||
|
if (has_bias) {
|
||||||
|
if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (expected_bias_op == GGML_OP_ADD) {
|
||||||
|
const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;
|
||||||
|
const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;
|
||||||
|
if (!up_has_mul || !gate_has_mul) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else { // GGML_OP_ADD_ID
|
||||||
|
if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
|
||||||
|
!ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ffn_up->src[1] != ffn_gate->src[1]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };
|
||||||
|
|
||||||
|
if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
|
||||||
|
ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);
|
||||||
|
|
||||||
|
//TODO: add support for fusion for split buffers
|
||||||
|
if (split) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
|
||||||
|
ggml_tensor * src0 = tensor->src[0];
|
||||||
|
ggml_tensor * src1 = tensor->src[1];
|
||||||
|
const ggml_tensor * dst = tensor;
|
||||||
|
|
||||||
|
const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
|
||||||
|
|
||||||
|
bool use_mul_mat_vec_f =
|
||||||
|
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
|
||||||
|
src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||||
|
|
||||||
|
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||||
|
use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
|
||||||
|
|
||||||
|
//we only support fusion for ncols_dst = 1
|
||||||
|
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return use_mul_mat_vec_f;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
|
||||||
|
ggml_tensor * src0 = tensor->src[0];
|
||||||
|
ggml_tensor * src1 = tensor->src[1];
|
||||||
|
const ggml_tensor * dst = tensor;
|
||||||
|
|
||||||
|
const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
|
||||||
|
ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
|
||||||
|
src0->view_src;
|
||||||
|
|
||||||
|
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
|
||||||
|
dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
|
||||||
|
|
||||||
|
// fusion is not universally faster on Pascal
|
||||||
|
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||||
|
if (cc <= GGML_CUDA_CC_PASCAL) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
//we only support fusion for ncols_dst = 1
|
||||||
|
if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return use_mul_mat_vec_q;
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||||
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
|
const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
|
||||||
|
|
||||||
@@ -2745,7 +2886,7 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (node->op == GGML_OP_SCALE &&
|
if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
|
||||||
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -2854,6 +2995,38 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
|
||||||
|
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
|
||||||
|
|
||||||
|
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
|
||||||
|
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
|
||||||
|
|
||||||
|
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
|
||||||
|
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
|
||||||
|
|
||||||
|
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
|
||||||
|
const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
|
||||||
|
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
|
||||||
|
const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3];
|
||||||
|
const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
|
||||||
|
|
||||||
|
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
|
||||||
|
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
|
||||||
|
|
||||||
|
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
|
||||||
|
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
|
||||||
|
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
|
||||||
|
|
||||||
|
if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -3004,6 +3177,184 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool fused_mul_mat_vec = false;
|
||||||
|
int fused_node_count = 0;
|
||||||
|
|
||||||
|
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
|
||||||
|
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||||
|
|
||||||
|
if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
|
||||||
|
ggml_tensor * glu = cgraph->nodes[i + 4];
|
||||||
|
ggml_tensor * gate_bias_n = glu->src[0];
|
||||||
|
ggml_tensor * up_bias_n = glu->src[1];
|
||||||
|
|
||||||
|
//we don't assume the order for {gate, up}. Instead infer it from the bias tensor
|
||||||
|
ggml_tensor * gate_n = nullptr;
|
||||||
|
ggml_tensor * up_n = nullptr;
|
||||||
|
|
||||||
|
if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
|
||||||
|
gate_n = cgraph->nodes[i];
|
||||||
|
up_n = cgraph->nodes[i + 2];
|
||||||
|
} else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
|
||||||
|
gate_n = cgraph->nodes[i + 2];
|
||||||
|
up_n = cgraph->nodes[i];
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
|
||||||
|
if (op_bias == GGML_OP_ADD) {
|
||||||
|
if (bias_node->src[0] == mul_node) {
|
||||||
|
return bias_node->src[1];
|
||||||
|
}
|
||||||
|
if (bias_node->src[1] == mul_node) {
|
||||||
|
return bias_node->src[0];
|
||||||
|
}
|
||||||
|
return (ggml_tensor *) nullptr;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
|
||||||
|
GGML_ASSERT(bias_node->src[0] == mul_node);
|
||||||
|
return bias_node->src[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
|
||||||
|
ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
|
||||||
|
|
||||||
|
if (!up_bias_tensor || !gate_bias_tensor) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = up_n->src[0];
|
||||||
|
const ggml_tensor * src1 = up_n->src[1];
|
||||||
|
const ggml_tensor * ids = up_n->src[2];
|
||||||
|
|
||||||
|
if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
|
||||||
|
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||||
|
fusion_data.gate = gate_n->src[0];
|
||||||
|
fusion_data.x_bias = up_bias_tensor;
|
||||||
|
fusion_data.gate_bias = gate_bias_tensor;
|
||||||
|
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||||
|
|
||||||
|
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||||
|
fused_mul_mat_vec = true;
|
||||||
|
fused_node_count = 5;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
|
||||||
|
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||||
|
fusion_data.gate = gate_n->src[0];
|
||||||
|
fusion_data.x_bias = up_bias_tensor;
|
||||||
|
fusion_data.gate_bias = gate_bias_tensor;
|
||||||
|
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||||
|
|
||||||
|
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||||
|
fused_mul_mat_vec = true;
|
||||||
|
fused_node_count = 5;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
|
||||||
|
ggml_tensor * glu = cgraph->nodes[i + 2];
|
||||||
|
ggml_tensor * gate = glu->src[0];
|
||||||
|
ggml_tensor * up = glu->src[1];
|
||||||
|
|
||||||
|
bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
|
||||||
|
|| (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
|
||||||
|
|
||||||
|
if (!ok) continue;
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = up->src[0];
|
||||||
|
const ggml_tensor * src1 = up->src[1];
|
||||||
|
const ggml_tensor * ids = up->src[2];
|
||||||
|
|
||||||
|
if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
|
||||||
|
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||||
|
fusion_data.gate = gate->src[0];
|
||||||
|
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||||
|
|
||||||
|
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||||
|
fused_mul_mat_vec = true;
|
||||||
|
fused_node_count = 3;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
|
||||||
|
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||||
|
fusion_data.gate = gate->src[0];
|
||||||
|
fusion_data.glu_op = ggml_get_glu_op(glu);
|
||||||
|
|
||||||
|
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
|
||||||
|
fused_mul_mat_vec = true;
|
||||||
|
fused_node_count = 3;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (fused_mul_mat_vec) {
|
||||||
|
i += fused_node_count - 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
fused_mul_mat_vec = false;
|
||||||
|
fused_node_count = 0;
|
||||||
|
|
||||||
|
for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
|
||||||
|
const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
|
||||||
|
|
||||||
|
if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * mm_node = cgraph->nodes[i];
|
||||||
|
ggml_tensor * bias_node = cgraph->nodes[i + 1];
|
||||||
|
|
||||||
|
ggml_tensor * bias_tensor = nullptr;
|
||||||
|
if (bias_op == GGML_OP_ADD) {
|
||||||
|
if (bias_node->src[0] == mm_node) {
|
||||||
|
bias_tensor = bias_node->src[1];
|
||||||
|
} else if (bias_node->src[1] == mm_node) {
|
||||||
|
bias_tensor = bias_node->src[0];
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (bias_node->src[0] != mm_node) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
bias_tensor = bias_node->src[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = mm_node->src[0];
|
||||||
|
const ggml_tensor * src1 = mm_node->src[1];
|
||||||
|
const ggml_tensor * ids = mm_node->src[2];
|
||||||
|
|
||||||
|
if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_cuda_mm_fusion_args_host fusion_data{};
|
||||||
|
fusion_data.x_bias = bias_tensor;
|
||||||
|
|
||||||
|
if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
|
||||||
|
ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
|
||||||
|
fused_mul_mat_vec = true;
|
||||||
|
fused_node_count = 2;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
|
||||||
|
ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
|
||||||
|
fused_mul_mat_vec = true;
|
||||||
|
fused_node_count = 2;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (fused_mul_mat_vec) {
|
||||||
|
i += fused_node_count - 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
|
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
|
||||||
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
|
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
#include "convert.cuh"
|
#include "unary.cuh"
|
||||||
#include "mmvf.cuh"
|
#include "mmvf.cuh"
|
||||||
|
#include "convert.cuh"
|
||||||
|
|
||||||
template <typename T, typename type_acc, int ncols_dst, int block_size>
|
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
|
||||||
static __global__ void mul_mat_vec_f(
|
static __global__ void mul_mat_vec_f(
|
||||||
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
|
||||||
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
||||||
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||||
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||||
@@ -24,58 +25,164 @@ static __global__ void mul_mat_vec_f(
|
|||||||
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
|
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
|
||||||
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
|
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
|
||||||
|
|
||||||
|
bool use_gate = false;
|
||||||
|
bool use_bias = false;
|
||||||
|
bool use_gate_bias = false;
|
||||||
|
ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
|
||||||
|
const T * gate_x = nullptr;
|
||||||
|
const float * x_bias = nullptr;
|
||||||
|
const float * gate_bias = nullptr;
|
||||||
|
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
use_gate = fusion.gate != nullptr;
|
||||||
|
use_bias = fusion.x_bias != nullptr;
|
||||||
|
use_gate_bias = fusion.gate_bias != nullptr;
|
||||||
|
glu_op = fusion.glu_op;
|
||||||
|
|
||||||
|
if (use_gate) {
|
||||||
|
gate_x = static_cast<const T *>(fusion.gate);
|
||||||
|
}
|
||||||
|
if (use_bias) {
|
||||||
|
x_bias = static_cast<const float *>(fusion.x_bias);
|
||||||
|
}
|
||||||
|
if (use_gate_bias) {
|
||||||
|
gate_bias = static_cast<const float *>(fusion.gate_bias);
|
||||||
|
use_gate_bias = use_gate;
|
||||||
|
} else {
|
||||||
|
use_gate_bias = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (use_gate) {
|
||||||
|
gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
|
||||||
|
}
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
const int channel_bias = ids ? channel_x : channel_dst;
|
||||||
|
if (use_bias) {
|
||||||
|
x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
|
||||||
|
}
|
||||||
|
if (use_gate_bias) {
|
||||||
|
gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const float2 * y2 = (const float2 *) y;
|
const float2 * y2 = (const float2 *) y;
|
||||||
|
|
||||||
extern __shared__ char data_mmv[];
|
extern __shared__ char data_mmv[];
|
||||||
float * buf_iw = (float *) data_mmv;
|
float * buf_iw = (float *) data_mmv;
|
||||||
|
float * buf_iw_gate = nullptr;
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
if (block_size > warp_size) {
|
if (block_size > warp_size) {
|
||||||
if (tid < warp_size) {
|
if (tid < warp_size) {
|
||||||
buf_iw[tid] = 0.0f;
|
buf_iw[tid] = 0.0f;
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
buf_iw_gate[tid] = 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
float sumf[ncols_dst] = {0.0f};
|
float sumf[ncols_dst] = {0.0f};
|
||||||
|
float sumf_gate[ncols_dst];
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
|
sumf_gate[j] = 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if constexpr (std::is_same_v<T, float>) {
|
if constexpr (std::is_same_v<T, float>) {
|
||||||
const float2 * x2 = (const float2 *) x;
|
const float2 * x2 = (const float2 *) x;
|
||||||
|
const float2 * gate_x2 = nullptr;
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
gate_x2 = (const float2 *) gate_x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||||
const float2 tmpx = x2[col2];
|
const float2 tmpx = x2[col2];
|
||||||
|
float2 tmpx_gate = make_float2(0.0f, 0.0f);
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
tmpx_gate = gate_x2[col2];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||||
|
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
||||||
|
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if constexpr (std::is_same_v<T, half>) {
|
} else if constexpr (std::is_same_v<T, half>) {
|
||||||
const half2 * x2 = (const half2 *) x;
|
const half2 * x2 = (const half2 *) x;
|
||||||
|
const half2 * gate_x2 = nullptr;
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
gate_x2 = (const half2 *) gate_x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (std::is_same_v<type_acc, float>) {
|
if (std::is_same_v<type_acc, float>) {
|
||||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||||
const float2 tmpx = __half22float2(x2[col2]);
|
const float2 tmpx = __half22float2(x2[col2]);
|
||||||
|
float2 tmpx_gate = make_float2(0.0f, 0.0f);
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
tmpx_gate = __half22float2(gate_x2[col2]);
|
||||||
|
}
|
||||||
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||||
|
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
||||||
|
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
#ifdef FP16_AVAILABLE
|
#ifdef FP16_AVAILABLE
|
||||||
half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
|
half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
|
||||||
|
half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||||
const half2 tmpx = x2[col2];
|
const half2 tmpx = x2[col2];
|
||||||
|
half2 tmpx_gate = make_half2(0.0f, 0.0f);
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
tmpx_gate = gate_x2[col2];
|
||||||
|
}
|
||||||
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||||
sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
|
sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
|
||||||
|
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,6 +190,15 @@ static __global__ void mul_mat_vec_f(
|
|||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
|
sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
|
sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // FP16_AVAILABLE
|
#endif // FP16_AVAILABLE
|
||||||
@@ -91,8 +207,20 @@ static __global__ void mul_mat_vec_f(
|
|||||||
//TODO: add support for ggml_cuda_mad for hip_bfloat162
|
//TODO: add support for ggml_cuda_mad for hip_bfloat162
|
||||||
#if defined(GGML_USE_HIP)
|
#if defined(GGML_USE_HIP)
|
||||||
const int * x2 = (const int *) x;
|
const int * x2 = (const int *) x;
|
||||||
|
const int * gate_x2 = nullptr;
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
gate_x2 = (const int *) gate_x;
|
||||||
|
}
|
||||||
|
}
|
||||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||||
const int tmpx = x2[col2];
|
const int tmpx = x2[col2];
|
||||||
|
int tmpx_gate = 0;
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
tmpx_gate = gate_x2[col2];
|
||||||
|
}
|
||||||
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||||
@@ -100,17 +228,45 @@ static __global__ void mul_mat_vec_f(
|
|||||||
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
|
const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
|
||||||
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
|
ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
|
||||||
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
|
ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
|
||||||
|
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
|
||||||
|
const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
|
||||||
|
ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
|
||||||
|
ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
|
const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
|
||||||
|
const nv_bfloat162 * gate_x2 = nullptr;
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
gate_x2 = (const nv_bfloat162 *) gate_x;
|
||||||
|
}
|
||||||
|
}
|
||||||
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||||
const nv_bfloat162 tmpx = x2[col2];
|
const nv_bfloat162 tmpx = x2[col2];
|
||||||
|
nv_bfloat162 tmpx_gate;
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
tmpx_gate = gate_x2[col2];
|
||||||
|
}
|
||||||
|
}
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
const float2 tmpy = y2[j*stride_col_y2 + col2];
|
||||||
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
|
||||||
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
|
||||||
|
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
|
||||||
|
ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -122,13 +278,31 @@ static __global__ void mul_mat_vec_f(
|
|||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
||||||
|
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (block_size > warp_size) {
|
if (block_size > warp_size) {
|
||||||
buf_iw[tid/warp_size] = sumf[j];
|
buf_iw[tid/warp_size] = sumf[j];
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
buf_iw_gate[tid/warp_size] = sumf_gate[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
if (tid < warp_size) {
|
if (tid < warp_size) {
|
||||||
sumf[j] = buf_iw[tid];
|
sumf[j] = buf_iw[tid];
|
||||||
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
sumf_gate[j] = buf_iw_gate[tid];
|
||||||
|
sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (j < ncols_dst) {
|
if (j < ncols_dst) {
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
@@ -139,12 +313,70 @@ static __global__ void mul_mat_vec_f(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dst[tid*stride_col_dst + row] = sumf[tid];
|
float value = sumf[tid];
|
||||||
|
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_bias) {
|
||||||
|
value += x_bias[tid*stride_col_dst + row];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (use_gate) {
|
||||||
|
float gate_value = sumf_gate[tid];
|
||||||
|
if (use_gate_bias) {
|
||||||
|
gate_value += gate_bias[tid*stride_col_dst + row];
|
||||||
|
}
|
||||||
|
switch (glu_op) {
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
value *= ggml_cuda_op_silu_single(gate_value);
|
||||||
|
break;
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
value *= ggml_cuda_op_gelu_single(gate_value);
|
||||||
|
break;
|
||||||
|
case GGML_GLU_OP_SWIGLU_OAI: {
|
||||||
|
value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dst[tid*stride_col_dst + row] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T, typename type_acc, int ncols_dst, int block_size>
|
||||||
|
static void mul_mat_vec_f_switch_fusion(
|
||||||
|
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||||
|
const int64_t ncols, const int64_t nrows,
|
||||||
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||||
|
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||||
|
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||||
|
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
|
||||||
|
|
||||||
|
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||||
|
if constexpr (ncols_dst == 1) {
|
||||||
|
if (has_fusion) {
|
||||||
|
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
|
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
||||||
|
|
||||||
|
mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
|
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||||
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, typename type_acc, int ncols_dst>
|
template <typename T, typename type_acc, int ncols_dst>
|
||||||
static void launch_mul_mat_vec_f_cuda(
|
void launch_mul_mat_vec_f_cuda(
|
||||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||||
const int64_t ncols, const int64_t nrows,
|
const int64_t ncols, const int64_t nrows,
|
||||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||||
@@ -176,57 +408,59 @@ static void launch_mul_mat_vec_f_cuda(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const int nbytes_shared = warp_size*sizeof(float);
|
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||||
|
|
||||||
|
const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
|
||||||
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
|
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
|
||||||
const dim3 block_dims(block_size_best, 1, 1);
|
const dim3 block_dims(block_size_best, 1, 1);
|
||||||
switch (block_size_best) {
|
switch (block_size_best) {
|
||||||
case 32: {
|
case 32: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||||
} break;
|
} break;
|
||||||
case 64: {
|
case 64: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||||
} break;
|
} break;
|
||||||
case 96: {
|
case 96: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||||
} break;
|
} break;
|
||||||
case 128: {
|
case 128: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||||
} break;
|
} break;
|
||||||
case 160: {
|
case 160: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||||
} break;
|
} break;
|
||||||
case 192: {
|
case 192: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||||
} break;
|
} break;
|
||||||
case 224: {
|
case 224: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||||
} break;
|
} break;
|
||||||
case 256: {
|
case 256: {
|
||||||
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
|
||||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
|
||||||
} break;
|
} break;
|
||||||
default: {
|
default: {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
@@ -236,7 +470,7 @@ static void launch_mul_mat_vec_f_cuda(
|
|||||||
|
|
||||||
template <typename T, typename type_acc>
|
template <typename T, typename type_acc>
|
||||||
static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
||||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||||
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
||||||
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
||||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||||
@@ -246,49 +480,49 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
|||||||
switch (ncols_dst) {
|
switch (ncols_dst) {
|
||||||
case 1:
|
case 1:
|
||||||
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
|
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
|
||||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
|
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
|
||||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
|
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
|
||||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
|
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
|
||||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
break;
|
break;
|
||||||
case 5:
|
case 5:
|
||||||
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
|
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
|
||||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
break;
|
break;
|
||||||
case 6:
|
case 6:
|
||||||
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
|
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
|
||||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
break;
|
break;
|
||||||
case 7:
|
case 7:
|
||||||
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
|
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
|
||||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
break;
|
break;
|
||||||
case 8:
|
case 8:
|
||||||
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
|
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
|
||||||
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
break;
|
break;
|
||||||
@@ -300,29 +534,31 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
|||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void mul_mat_vec_f_cuda(
|
static void mul_mat_vec_f_cuda(
|
||||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||||
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
||||||
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
|
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
|
||||||
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
|
||||||
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
||||||
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
||||||
enum ggml_prec prec, cudaStream_t stream) {
|
enum ggml_prec prec, cudaStream_t stream) {
|
||||||
|
|
||||||
if constexpr(std::is_same_v<T, half>) {
|
if constexpr(std::is_same_v<T, half>) {
|
||||||
if (prec == GGML_PREC_DEFAULT) {
|
if (prec == GGML_PREC_DEFAULT) {
|
||||||
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
|
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
|
||||||
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
|
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
|
||||||
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
||||||
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||||
|
const ggml_cuda_mm_fusion_args_host * fusion) {
|
||||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
@@ -348,6 +584,30 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
|
|||||||
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
||||||
float * dst_d = (float *) dst->data;
|
float * dst_d = (float *) dst->data;
|
||||||
|
|
||||||
|
ggml_cuda_mm_fusion_args_device fusion_local{};
|
||||||
|
|
||||||
|
if (fusion) {
|
||||||
|
GGML_ASSERT( !ids || dst->ne[2] == 1);
|
||||||
|
GGML_ASSERT( ids || dst->ne[1] == 1);
|
||||||
|
if (fusion->x_bias) {
|
||||||
|
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
|
||||||
|
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
|
||||||
|
fusion_local.x_bias = fusion->x_bias->data;
|
||||||
|
}
|
||||||
|
if (fusion->gate) {
|
||||||
|
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
|
||||||
|
fusion_local.gate = fusion->gate->data;
|
||||||
|
}
|
||||||
|
if (fusion->gate_bias) {
|
||||||
|
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
|
||||||
|
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
|
||||||
|
fusion_local.gate_bias = fusion->gate_bias->data;
|
||||||
|
}
|
||||||
|
fusion_local.glu_op = fusion->glu_op;
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t s01 = src0->nb[1] / ts_src0;
|
const int64_t s01 = src0->nb[1] / ts_src0;
|
||||||
const int64_t s11 = src1->nb[1] / ts_src1;
|
const int64_t s11 = src1->nb[1] / ts_src1;
|
||||||
const int64_t s1 = dst->nb[1] / ts_dst;
|
const int64_t s1 = dst->nb[1] / ts_dst;
|
||||||
@@ -370,19 +630,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
|
|||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: {
|
case GGML_TYPE_F32: {
|
||||||
const float * src0_d = (const float *) src0->data;
|
const float * src0_d = (const float *) src0->data;
|
||||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_F16: {
|
case GGML_TYPE_F16: {
|
||||||
const half * src0_d = (const half *) src0->data;
|
const half * src0_d = (const half *) src0->data;
|
||||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_BF16: {
|
case GGML_TYPE_BF16: {
|
||||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
||||||
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
||||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||||
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
||||||
} break;
|
} break;
|
||||||
@@ -409,7 +669,6 @@ void ggml_cuda_op_mul_mat_vec_f(
|
|||||||
const int cc = ggml_cuda_info().devices[id].cc;
|
const int cc = ggml_cuda_info().devices[id].cc;
|
||||||
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
|
||||||
|
|
||||||
|
|
||||||
// ggml_cuda_op provides single, contiguous matrices
|
// ggml_cuda_op provides single, contiguous matrices
|
||||||
const int64_t stride_row = ne00;
|
const int64_t stride_row = ne00;
|
||||||
const int64_t stride_col_y = ne10;
|
const int64_t stride_col_y = ne10;
|
||||||
@@ -426,22 +685,23 @@ void ggml_cuda_op_mul_mat_vec_f(
|
|||||||
const int64_t stride_sample_y = 0;
|
const int64_t stride_sample_y = 0;
|
||||||
const int64_t stride_sample_dst = 0;
|
const int64_t stride_sample_dst = 0;
|
||||||
|
|
||||||
|
ggml_cuda_mm_fusion_args_device empty{};
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: {
|
case GGML_TYPE_F32: {
|
||||||
const float * src0_d = (const float *) src0_dd_i;
|
const float * src0_d = (const float *) src0_dd_i;
|
||||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_F16: {
|
case GGML_TYPE_F16: {
|
||||||
const half * src0_d = (const half *) src0_dd_i;
|
const half * src0_d = (const half *) src0_dd_i;
|
||||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_BF16: {
|
case GGML_TYPE_BF16: {
|
||||||
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
||||||
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
||||||
} break;
|
} break;
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
|
|
||||||
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||||
|
const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
||||||
|
|
||||||
void ggml_cuda_op_mul_mat_vec_f(
|
void ggml_cuda_op_mul_mat_vec_f(
|
||||||
ggml_backend_cuda_context & ctx,
|
ggml_backend_cuda_context & ctx,
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#include "mmvq.cuh"
|
#include "mmvq.cuh"
|
||||||
#include "quantize.cuh"
|
#include "quantize.cuh"
|
||||||
|
#include "unary.cuh"
|
||||||
#include "vecdotq.cuh"
|
#include "vecdotq.cuh"
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
@@ -82,7 +83,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
|
|||||||
return MMVQ_PARAMETERS_GENERIC;
|
return MMVQ_PARAMETERS_GENERIC;
|
||||||
}
|
}
|
||||||
|
|
||||||
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
|
static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
|
||||||
if (table_id == MMVQ_PARAMETERS_GENERIC) {
|
if (table_id == MMVQ_PARAMETERS_GENERIC) {
|
||||||
switch (ncols_dst) {
|
switch (ncols_dst) {
|
||||||
case 1:
|
case 1:
|
||||||
@@ -136,11 +137,11 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
|
|||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <ggml_type type, int ncols_dst>
|
|
||||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||||
|
template <ggml_type type, int ncols_dst, bool has_fusion>
|
||||||
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
|
||||||
static __global__ void mul_mat_vec_q(
|
static __global__ void mul_mat_vec_q(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
|
||||||
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
||||||
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
|
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
|
||||||
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
|
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
|
||||||
@@ -169,8 +170,38 @@ static __global__ void mul_mat_vec_q(
|
|||||||
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
|
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
|
||||||
const uint32_t sample_y = sample_dst;
|
const uint32_t sample_y = sample_dst;
|
||||||
|
|
||||||
|
bool use_gate = false;
|
||||||
|
bool use_bias = false;
|
||||||
|
bool use_gate_bias = false;
|
||||||
|
const void * vgate = nullptr;
|
||||||
|
const float * x_bias = nullptr;
|
||||||
|
const float * gate_bias = nullptr;
|
||||||
|
ggml_glu_op active_glu;
|
||||||
|
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
use_gate = fusion.gate != nullptr;
|
||||||
|
use_bias = fusion.x_bias != nullptr;
|
||||||
|
use_gate_bias = fusion.gate_bias != nullptr && use_gate;
|
||||||
|
vgate = fusion.gate;
|
||||||
|
x_bias = (const float *) fusion.x_bias;
|
||||||
|
gate_bias = (const float *) fusion.gate_bias;
|
||||||
|
active_glu = fusion.glu_op;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t channel_bias = ids ? channel_x : channel_dst;
|
||||||
|
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_bias) {
|
||||||
|
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
||||||
|
}
|
||||||
|
if (use_gate_bias) {
|
||||||
|
gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// partial sum for each thread
|
// partial sum for each thread
|
||||||
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
|
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
|
||||||
|
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
|
||||||
|
|
||||||
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
|
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
|
||||||
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
|
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
|
||||||
@@ -187,17 +218,35 @@ static __global__ void mul_mat_vec_q(
|
|||||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||||
tmp[j][i] += vec_dot_q_cuda(
|
tmp[j][i] += vec_dot_q_cuda(
|
||||||
vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
|
vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
tmp_gate[j][i] += vec_dot_q_cuda(
|
||||||
|
vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
|
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
|
||||||
|
__shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
|
||||||
|
if constexpr (!has_fusion) {
|
||||||
|
(void) tmp_shared_gate;
|
||||||
|
} else if (!use_gate) {
|
||||||
|
(void) tmp_shared_gate;
|
||||||
|
}
|
||||||
|
|
||||||
if (threadIdx.y > 0) {
|
if (threadIdx.y > 0) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_dst; ++j) {
|
for (int j = 0; j < ncols_dst; ++j) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||||
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
|
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -216,12 +265,49 @@ static __global__ void mul_mat_vec_q(
|
|||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int l = 0; l < nwarps-1; ++l) {
|
for (int l = 0; l < nwarps-1; ++l) {
|
||||||
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
|
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
|
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_gate) {
|
||||||
|
tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
|
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
|
||||||
dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
|
float result = tmp[j][threadIdx.x];
|
||||||
|
if constexpr (has_fusion) {
|
||||||
|
if (use_bias) {
|
||||||
|
result += x_bias[j*stride_col_dst + threadIdx.x];
|
||||||
|
}
|
||||||
|
if (use_gate) {
|
||||||
|
float gate_value = tmp_gate[j][threadIdx.x];
|
||||||
|
if (use_gate_bias) {
|
||||||
|
gate_value += gate_bias[j*stride_col_dst + threadIdx.x];
|
||||||
|
}
|
||||||
|
switch (active_glu) {
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
result *= ggml_cuda_op_silu_single(gate_value);
|
||||||
|
break;
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
result *= ggml_cuda_op_gelu_single(gate_value);
|
||||||
|
break;
|
||||||
|
case GGML_GLU_OP_SWIGLU_OAI: {
|
||||||
|
result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
result = result * gate_value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dst[j*stride_col_dst + threadIdx.x] = result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -235,9 +321,37 @@ static std::pair<dim3, dim3> calc_launch_params(
|
|||||||
return {block_nums, block_dims};
|
return {block_nums, block_dims};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<ggml_type type, int c_ncols_dst>
|
||||||
|
static void mul_mat_vec_q_switch_fusion(
|
||||||
|
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||||
|
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
|
||||||
|
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
|
||||||
|
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
|
||||||
|
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
|
||||||
|
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
|
||||||
|
|
||||||
|
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||||
|
if constexpr (c_ncols_dst == 1) {
|
||||||
|
if (has_fusion) {
|
||||||
|
mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
|
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
|
||||||
|
|
||||||
|
mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||||
|
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
|
}
|
||||||
|
|
||||||
template <ggml_type type>
|
template <ggml_type type>
|
||||||
static void mul_mat_vec_q_switch_ncols_dst(
|
static void mul_mat_vec_q_switch_ncols_dst(
|
||||||
const void * vx, const void * vy, const int32_t * ids, float * dst,
|
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_dst,
|
const int ncols_x, const int nrows_x, const int ncols_dst,
|
||||||
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
||||||
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
||||||
@@ -256,80 +370,83 @@ static void mul_mat_vec_q_switch_ncols_dst(
|
|||||||
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||||
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
|
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
|
||||||
|
|
||||||
|
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
|
||||||
|
|
||||||
GGML_ASSERT(!ids || ncols_dst == 1);
|
GGML_ASSERT(!ids || ncols_dst == 1);
|
||||||
switch (ncols_dst) {
|
switch (ncols_dst) {
|
||||||
case 1: {
|
case 1: {
|
||||||
constexpr int c_ncols_dst = 1;
|
constexpr int c_ncols_dst = 1;
|
||||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||||
|
dims.first, dims.second, 0, stream);
|
||||||
} break;
|
} break;
|
||||||
case 2: {
|
case 2: {
|
||||||
constexpr int c_ncols_dst = 2;
|
constexpr int c_ncols_dst = 2;
|
||||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||||
|
dims.first, dims.second, 0, stream);
|
||||||
} break;
|
} break;
|
||||||
case 3: {
|
case 3: {
|
||||||
constexpr int c_ncols_dst = 3;
|
constexpr int c_ncols_dst = 3;
|
||||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||||
|
dims.first, dims.second, 0, stream);
|
||||||
} break;
|
} break;
|
||||||
case 4: {
|
case 4: {
|
||||||
constexpr int c_ncols_dst = 4;
|
constexpr int c_ncols_dst = 4;
|
||||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||||
|
dims.first, dims.second, 0, stream);
|
||||||
} break;
|
} break;
|
||||||
case 5: {
|
case 5: {
|
||||||
constexpr int c_ncols_dst = 5;
|
constexpr int c_ncols_dst = 5;
|
||||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||||
|
dims.first, dims.second, 0, stream);
|
||||||
} break;
|
} break;
|
||||||
case 6: {
|
case 6: {
|
||||||
constexpr int c_ncols_dst = 6;
|
constexpr int c_ncols_dst = 6;
|
||||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||||
|
dims.first, dims.second, 0, stream);
|
||||||
} break;
|
} break;
|
||||||
case 7: {
|
case 7: {
|
||||||
constexpr int c_ncols_dst = 7;
|
constexpr int c_ncols_dst = 7;
|
||||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||||
|
dims.first, dims.second, 0, stream);
|
||||||
} break;
|
} break;
|
||||||
case 8: {
|
case 8: {
|
||||||
constexpr int c_ncols_dst = 8;
|
constexpr int c_ncols_dst = 8;
|
||||||
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
|
||||||
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
|
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
|
|
||||||
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
|
||||||
|
dims.first, dims.second, 0, stream);
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
GGML_UNUSED(has_fusion);
|
||||||
|
}
|
||||||
static void mul_mat_vec_q_switch_type(
|
static void mul_mat_vec_q_switch_type(
|
||||||
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst,
|
const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
|
||||||
const int ncols_x, const int nrows_x, const int ncols_dst,
|
const int ncols_x, const int nrows_x, const int ncols_dst,
|
||||||
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
const int stride_row_x, const int stride_col_y, const int stride_col_dst,
|
||||||
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
|
||||||
@@ -339,143 +456,123 @@ static void mul_mat_vec_q_switch_type(
|
|||||||
switch (type_x) {
|
switch (type_x) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_MXFP4:
|
case GGML_TYPE_MXFP4:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q3_K:
|
case GGML_TYPE_Q3_K:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q4_K:
|
case GGML_TYPE_Q4_K:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_K:
|
case GGML_TYPE_Q5_K:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ2_XXS:
|
case GGML_TYPE_IQ2_XXS:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ2_XS:
|
case GGML_TYPE_IQ2_XS:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ2_S:
|
case GGML_TYPE_IQ2_S:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ1_S:
|
case GGML_TYPE_IQ1_S:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ1_M:
|
case GGML_TYPE_IQ1_M:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ4_NL:
|
case GGML_TYPE_IQ4_NL:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ4_XS:
|
case GGML_TYPE_IQ4_XS:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ3_S:
|
case GGML_TYPE_IQ3_S:
|
||||||
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
|
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
|
||||||
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
|
||||||
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
||||||
stream);
|
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
@@ -484,7 +581,8 @@ static void mul_mat_vec_q_switch_type(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_mul_mat_vec_q(
|
void ggml_cuda_mul_mat_vec_q(
|
||||||
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
|
||||||
|
const ggml_cuda_mm_fusion_args_host * fusion) {
|
||||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
|
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
|
||||||
@@ -508,6 +606,31 @@ void ggml_cuda_mul_mat_vec_q(
|
|||||||
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
|
||||||
float * dst_d = (float *) dst->data;
|
float * dst_d = (float *) dst->data;
|
||||||
|
|
||||||
|
ggml_cuda_mm_fusion_args_device fusion_local{};
|
||||||
|
|
||||||
|
if (fusion) {
|
||||||
|
GGML_ASSERT( !ids || dst->ne[2] == 1);
|
||||||
|
GGML_ASSERT( ids || dst->ne[1] == 1);
|
||||||
|
|
||||||
|
if (fusion->x_bias) {
|
||||||
|
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
|
||||||
|
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
|
||||||
|
fusion_local.x_bias = fusion->x_bias->data;
|
||||||
|
}
|
||||||
|
if (fusion->gate) {
|
||||||
|
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
|
||||||
|
fusion_local.gate = fusion->gate->data;
|
||||||
|
}
|
||||||
|
if (fusion->gate_bias) {
|
||||||
|
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
|
||||||
|
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
|
||||||
|
fusion_local.gate_bias = fusion->gate_bias->data;
|
||||||
|
}
|
||||||
|
fusion_local.glu_op = fusion->glu_op;
|
||||||
|
}
|
||||||
|
|
||||||
// If src0 is a temporary compute buffer, clear any potential padding.
|
// If src0 is a temporary compute buffer, clear any potential padding.
|
||||||
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
|
if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
|
||||||
const size_t size_data = ggml_nbytes(src0);
|
const size_t size_data = ggml_nbytes(src0);
|
||||||
@@ -549,10 +672,10 @@ void ggml_cuda_mul_mat_vec_q(
|
|||||||
const int64_t stride_channel_y = ids ? s11 : s12;
|
const int64_t stride_channel_y = ids ? s11 : s12;
|
||||||
|
|
||||||
mul_mat_vec_q_switch_type(
|
mul_mat_vec_q_switch_type(
|
||||||
src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00,
|
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
|
||||||
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
|
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
|
||||||
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
||||||
ne03, ne3, s03, s13, s3, stream);
|
ne03, ne3, s03, s13, s3, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_mul_mat_vec_q(
|
void ggml_cuda_op_mul_mat_vec_q(
|
||||||
@@ -578,8 +701,9 @@ void ggml_cuda_op_mul_mat_vec_q(
|
|||||||
const int stride_row_x = ne00 / ggml_blck_size(src0->type);
|
const int stride_row_x = ne00 / ggml_blck_size(src0->type);
|
||||||
const int stride_col_y = src1_padded_row_size / QK8_1;
|
const int stride_col_y = src1_padded_row_size / QK8_1;
|
||||||
|
|
||||||
|
ggml_cuda_mm_fusion_args_device fusion_local{};
|
||||||
mul_mat_vec_q_switch_type(
|
mul_mat_vec_q_switch_type(
|
||||||
src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
|
src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
|
||||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
|
||||||
|
|
||||||
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
|
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
|
||||||
|
|
||||||
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
|
void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
|
||||||
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
|
const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
|
||||||
|
|
||||||
void ggml_cuda_op_mul_mat_vec_q(
|
void ggml_cuda_op_mul_mat_vec_q(
|
||||||
ggml_backend_cuda_context & ctx,
|
ggml_backend_cuda_context & ctx,
|
||||||
|
|||||||
@@ -18,10 +18,7 @@ static __device__ __forceinline__ float op_step(float x) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float op_gelu(float x) {
|
static __device__ __forceinline__ float op_gelu(float x) {
|
||||||
const float GELU_COEF_A = 0.044715f;
|
return ggml_cuda_op_gelu_single(x);
|
||||||
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
|
||||||
|
|
||||||
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float op_gelu_erf(float x) {
|
static __device__ __forceinline__ float op_gelu_erf(float x) {
|
||||||
@@ -37,7 +34,7 @@ static __device__ __forceinline__ float op_gelu_quick(float x) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float op_silu(float x) {
|
static __device__ __forceinline__ float op_silu(float x) {
|
||||||
return x / (1.0f + expf(-x));
|
return ggml_cuda_op_silu_single(x);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float op_tanh(float x) {
|
static __device__ __forceinline__ float op_tanh(float x) {
|
||||||
@@ -317,13 +314,8 @@ static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, cons
|
|||||||
|
|
||||||
float xi = x[j0];
|
float xi = x[j0];
|
||||||
float gi = g[j1];
|
float gi = g[j1];
|
||||||
xi = fminf(xi, limit);
|
|
||||||
gi = fmaxf(fminf(gi, limit), -limit);
|
|
||||||
|
|
||||||
float out_glu = xi / (1.0f + expf(-xi * alpha));
|
dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit);
|
||||||
out_glu = out_glu * (1.0f + gi);
|
|
||||||
|
|
||||||
dst[i] = out_glu;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
#pragma once
|
||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
|
|
||||||
#define CUDA_NEG_BLOCK_SIZE 256
|
#define CUDA_NEG_BLOCK_SIZE 256
|
||||||
@@ -75,3 +76,23 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|||||||
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
|
||||||
|
return x / (1.0f + expf(-x));
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) {
|
||||||
|
const float GELU_COEF_A = 0.044715f;
|
||||||
|
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||||
|
|
||||||
|
return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
|
||||||
|
x = fminf(x, limit);
|
||||||
|
g = fmaxf(fminf(g, limit), -limit);
|
||||||
|
|
||||||
|
float out_glu = x / (1.0f + expf(-x * alpha));
|
||||||
|
out_glu = out_glu * (1.0f + g);
|
||||||
|
return out_glu;
|
||||||
|
}
|
||||||
|
|||||||
@@ -810,6 +810,9 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//expand here so that we can fuse ffn gate
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
if (gate && type_gate == LLM_FFN_PAR) {
|
if (gate && type_gate == LLM_FFN_PAR) {
|
||||||
cur = ggml_mul(ctx0, cur, tmp);
|
cur = ggml_mul(ctx0, cur, tmp);
|
||||||
cb(cur, "ffn_gate_par", il);
|
cb(cur, "ffn_gate_par", il);
|
||||||
@@ -1091,6 +1094,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//expand here so that we can fuse ffn gate
|
||||||
|
ggml_build_forward_expand(gf, cur);
|
||||||
|
|
||||||
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
||||||
cb(experts, "ffn_moe_down", il);
|
cb(experts, "ffn_moe_down", il);
|
||||||
|
|
||||||
|
|||||||
@@ -4721,6 +4721,140 @@ struct test_topk_moe: public test_case {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct test_mul_mat_vec_fusion : public test_case {
|
||||||
|
const ggml_type type;
|
||||||
|
const ggml_glu_op glu_op;
|
||||||
|
const int64_t m;
|
||||||
|
const int64_t n;
|
||||||
|
const int64_t k;
|
||||||
|
const bool use_id;
|
||||||
|
const int n_mats;
|
||||||
|
const int n_used;
|
||||||
|
const bool b; // broadcast b matrix (only for use_id)
|
||||||
|
const bool with_bias;
|
||||||
|
const bool with_gate;
|
||||||
|
|
||||||
|
test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
|
||||||
|
bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true)
|
||||||
|
: type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
|
||||||
|
if (use_id) {
|
||||||
|
GGML_ASSERT(n_used <= n_mats);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string op_desc(ggml_tensor * t) override {
|
||||||
|
GGML_UNUSED(t);
|
||||||
|
return "MUL_MAT_VEC_FUSION";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool run_whole_graph() override { return true; }
|
||||||
|
|
||||||
|
ggml_tensor * build_gate(ggml_context * ctx, ggml_tensor * ffn_gate, ggml_tensor * ffn_up) {
|
||||||
|
ggml_tensor * out = nullptr;
|
||||||
|
if (with_gate) {
|
||||||
|
if (glu_op == GGML_GLU_OP_SWIGLU_OAI) {
|
||||||
|
constexpr float alpha = 1.702f;
|
||||||
|
constexpr float limit = 7.0f;
|
||||||
|
out = ggml_swiglu_oai(ctx, ffn_gate, ffn_up, alpha, limit);
|
||||||
|
} else {
|
||||||
|
out = ggml_glu_split(ctx, ffn_gate, ffn_up, glu_op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
if (!use_id) {
|
||||||
|
std::array<int64_t, 4> ne = {k, m, 1, 1};
|
||||||
|
std::array<int64_t, 4> ne0 = {k, n, 1, 1};
|
||||||
|
|
||||||
|
ggml_tensor * cur = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
|
||||||
|
ggml_tensor * gate = with_gate ? ggml_new_tensor(ctx, type, 4, ne0.data()) : nullptr;
|
||||||
|
ggml_tensor * up = ggml_new_tensor(ctx, type, 4, ne0.data());
|
||||||
|
|
||||||
|
ggml_tensor * ffn_up = ggml_mul_mat(ctx, up, cur);
|
||||||
|
if (with_bias) {
|
||||||
|
std::array<int64_t, 4> bias_ne = {ffn_up->ne[0], 1, 1, 1};
|
||||||
|
ggml_tensor * up_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
|
||||||
|
ffn_up = ggml_add(ctx, ffn_up, up_bias);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * ffn_gate = with_gate ? ggml_mul_mat(ctx, gate, cur) : nullptr;
|
||||||
|
if (with_bias && with_gate) {
|
||||||
|
std::array<int64_t, 4> bias_ne = {ffn_gate->ne[0], 1, 1, 1};
|
||||||
|
ggml_tensor * gate_bias = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias_ne.data());
|
||||||
|
ffn_gate = ggml_add(ctx, ffn_gate, gate_bias);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
|
||||||
|
ggml_set_name(out, "out");
|
||||||
|
return out;
|
||||||
|
} else {
|
||||||
|
ggml_tensor * gates = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
|
||||||
|
ggml_tensor * ups = ggml_new_tensor_3d(ctx, type, k, n, n_mats);
|
||||||
|
ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, m);
|
||||||
|
|
||||||
|
if (n_used != n_mats) {
|
||||||
|
ids = ggml_view_2d(ctx, ids, n_used, m, ids->nb[1], 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, k, this->b ? 1 : n_used, m);
|
||||||
|
ggml_set_name(cur, "cur");
|
||||||
|
|
||||||
|
ggml_tensor * ffn_up = ggml_mul_mat_id(ctx, ups, cur, ids);
|
||||||
|
if (with_bias) {
|
||||||
|
ggml_tensor * up_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_up->ne[0], n_mats);
|
||||||
|
ffn_up = ggml_add_id(ctx, ffn_up, up_bias_param, ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * ffn_gate = with_gate? ggml_mul_mat_id(ctx, gates, cur, ids) : nullptr;
|
||||||
|
if (with_bias && with_gate) {
|
||||||
|
ggml_tensor * gate_bias_param = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ffn_gate->ne[0], n_mats);
|
||||||
|
ffn_gate = ggml_add_id(ctx, ffn_gate, gate_bias_param, ids);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
|
||||||
|
ggml_set_name(out, "out");
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void initialize_tensors(ggml_context * ctx) override {
|
||||||
|
if (!use_id) {
|
||||||
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
|
init_tensor_uniform(t);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::random_device rd;
|
||||||
|
std::default_random_engine rng(rd());
|
||||||
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
|
if (t->type == GGML_TYPE_I32) {
|
||||||
|
if (ggml_is_view_op(t->op)) { continue; }
|
||||||
|
// ids
|
||||||
|
for (int64_t r = 0; r < ggml_nrows(t); r++) {
|
||||||
|
std::vector<int32_t> data(t->ne[0]);
|
||||||
|
for (int i = 0; i < t->ne[0]; i++) {
|
||||||
|
data[i] = i % n_mats;
|
||||||
|
}
|
||||||
|
std::shuffle(data.begin(), data.end(), rng);
|
||||||
|
ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
init_tensor_uniform(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
double max_nmse_err() override {
|
||||||
|
return 5e-3;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// GGML_OP_SUM
|
// GGML_OP_SUM
|
||||||
struct test_sum : public test_case {
|
struct test_sum : public test_case {
|
||||||
const ggml_type type;
|
const ggml_type type;
|
||||||
@@ -6983,6 +7117,33 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||||||
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
|
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
|
||||||
test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));
|
test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));
|
||||||
|
|
||||||
|
for (ggml_type type : base_types) {
|
||||||
|
for (bool with_gate : {false, true}) {
|
||||||
|
for (bool use_id : {false, true}) {
|
||||||
|
for (bool b : {false, true}) {
|
||||||
|
if (!use_id && b) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (bool with_bias : {false, true}) {
|
||||||
|
if (!with_gate && !with_bias) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (ggml_glu_op glu_op : {GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU}) {
|
||||||
|
if (!with_bias && glu_op == GGML_GLU_OP_SWIGLU_OAI) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!with_gate && glu_op != GGML_GLU_OP_SWIGLU) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
|
||||||
|
use_id, 16, 8, b, with_bias, with_gate));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (bool with_norm : {false, true}) {
|
for (bool with_norm : {false, true}) {
|
||||||
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
|
test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
|
||||||
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
|
test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
|
||||||
|
|||||||
Reference in New Issue
Block a user