mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-03 09:22:01 +00:00
ggml-zdnn: inital backend impl
Signed-off-by: Aaron Teo <aaron.teo1@ibm.com> ggml-zdnn: temp change z17 to arch15 Signed-off-by: Aaron Teo <aaron.teo1@ibm.com> ggml-zdnn: fix build bugs Signed-off-by: Aaron Teo <aaron.teo1@ibm.com>
This commit is contained in:
@@ -183,6 +183,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
|
||||
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
|
||||
option(GGML_WEBGPU "ggml: use WebGPU" OFF)
|
||||
option(GGML_WEBGPU_DEBUG "ggml: enable WebGPU debug output" OFF)
|
||||
option(GGML_ZDNN "ggml: use zDNN" OFF)
|
||||
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
|
||||
option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
|
||||
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
|
||||
|
||||
16
ggml/include/ggml-zdnn.h
Normal file
16
ggml/include/ggml-zdnn.h
Normal file
@@ -0,0 +1,16 @@
|
||||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
GGML_BACKEND_API ggml_backend_t ggml_backend_zdnn_init(void);
|
||||
|
||||
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_zdnn_reg(void);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -371,6 +371,7 @@ ggml_add_backend(RPC)
|
||||
ggml_add_backend(SYCL)
|
||||
ggml_add_backend(Vulkan)
|
||||
ggml_add_backend(WebGPU)
|
||||
ggml_add_backend(zDNN)
|
||||
ggml_add_backend(OpenCL)
|
||||
|
||||
foreach (target ggml-base ggml)
|
||||
|
||||
@@ -49,6 +49,10 @@
|
||||
#include "ggml-webgpu.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_ZDNN
|
||||
#include "ggml-zdnn.h"
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_OPENCL
|
||||
#include "ggml-opencl.h"
|
||||
#endif
|
||||
@@ -180,6 +184,9 @@ struct ggml_backend_registry {
|
||||
#ifdef GGML_USE_WEBGPU
|
||||
register_backend(ggml_backend_webgpu_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_ZDNN
|
||||
register_backend(ggml_backend_zdnn_reg());
|
||||
#endif
|
||||
#ifdef GGML_USE_OPENCL
|
||||
register_backend(ggml_backend_opencl_reg());
|
||||
#endif
|
||||
|
||||
@@ -457,7 +457,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
elseif (${S390X_M} MATCHES "9175|9176")
|
||||
# NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
|
||||
message(STATUS "z17 target")
|
||||
list(APPEND ARCH_FLAGS -march=z17)
|
||||
list(APPEND ARCH_FLAGS -march=arch15)
|
||||
else()
|
||||
message(STATUS "Unknown target")
|
||||
message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
|
||||
|
||||
38
ggml/src/ggml-zdnn/CMakeLists.txt
Normal file
38
ggml/src/ggml-zdnn/CMakeLists.txt
Normal file
@@ -0,0 +1,38 @@
|
||||
if (GGML_ZDNN)
|
||||
if (DEFINED ZDNN_ROOT)
|
||||
message(STATUS "zdnn: using ZDNN_ROOT override: ${ZDNN_ROOT}")
|
||||
set(ZDNN_HINT "${ZDNN_ROOT}")
|
||||
else()
|
||||
set(ZDNN_HINT "")
|
||||
endif()
|
||||
|
||||
find_path(ZDNN_INCLUDE
|
||||
NAMES zdnn.h
|
||||
HINTS ${ZDNN_HINT} /usr /usr/local
|
||||
PATH_SUFFIXES include)
|
||||
if (ZDNN_INCLUDE)
|
||||
message(STATUS "zdnn: found include: ${ZDNN_INCLUDE}")
|
||||
else()
|
||||
message(FATAL_ERROR "zdnn: include directory not found, please set ZDNN_ROOT to the proper path if necessary")
|
||||
endif()
|
||||
|
||||
find_library(ZDNN_LIB
|
||||
NAMES zdnn
|
||||
HINTS ${ZDNN_HINT} /usr /usr/local
|
||||
PATH_SUFFIXES lib lib64)
|
||||
if (ZDNN_LIB)
|
||||
message(STATUS "zdnn: found library: ${ZDNN_LIB}")
|
||||
else()
|
||||
message(FATAL_ERROR "zdnn: library not found, please set ZDNN_ROOT to the proper path if necessary")
|
||||
endif()
|
||||
|
||||
file(GLOB GGML_SOURCES_ZDNN "*.c" "*.cpp")
|
||||
file(GLOB GGML_HEADERS_ZDNN "*.h" "*.hpp")
|
||||
|
||||
ggml_add_backend_library(ggml-zdnn ${GGML_HEADERS_ZDNN} ${GGML_SOURCES_ZDNN})
|
||||
target_link_libraries(ggml-zdnn PRIVATE ${ZDNN_LIB})
|
||||
target_include_directories(ggml-zdnn PRIVATE ${ZDNN_INCLUDE})
|
||||
target_link_directories(ggml-zdnn PRIVATE ${ZDNN_LIB})
|
||||
|
||||
target_compile_definitions(ggml-zdnn PRIVATE GGML_ZDNN GGML_USE_ZDNN)
|
||||
endif()
|
||||
59
ggml/src/ggml-zdnn/ggml-zdnn-impl.h
Normal file
59
ggml/src/ggml-zdnn/ggml-zdnn-impl.h
Normal file
@@ -0,0 +1,59 @@
|
||||
#ifndef GGML_ZDNN_IMPL
|
||||
#define GGML_ZDNN_IMPL
|
||||
|
||||
#include "zdnn.h"
|
||||
#include "ggml.h"
|
||||
#include "ggml-zdnn.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vecintrin.h>
|
||||
|
||||
#define GGML_ZDNN_NAME "zDNN"
|
||||
#define GGML_ZDNN_VERSION ZDNN_VERNUM
|
||||
|
||||
#define vec_neg(a) (-(a)) // Vector Negate
|
||||
#define vec_add(a, b) ((a) + (b)) // Vector Add
|
||||
#define vec_sub(a, b) ((a) - (b)) // Vector Subtract
|
||||
#define vec_mul(a, b) ((a) * (b)) // Vector Multiply
|
||||
#define vec_div(a, b) ((a) / (b)) // Vector Divide
|
||||
#define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
|
||||
#define vec_sra(a, b) ((a) >> (b)) // Vector Shift Right
|
||||
#define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
|
||||
#define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet
|
||||
#define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet
|
||||
|
||||
#ifndef vec_and
|
||||
#define vec_and(a, b) ((a) & (b)) // Vector AND
|
||||
#endif
|
||||
|
||||
#ifndef vec_or
|
||||
#define vec_or(a, b) ((a) | (b)) // Vector OR
|
||||
#endif
|
||||
|
||||
#ifndef vec_xor
|
||||
#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
|
||||
#endif
|
||||
|
||||
typedef signed char char8x16_t __attribute__((vector_size(16)));
|
||||
typedef unsigned char uchar8x16_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef int8_t int8x16_t __attribute__((vector_size(16)));
|
||||
typedef int16_t int16x8_t __attribute__((vector_size(16)));
|
||||
typedef int32_t int32x4_t __attribute__((vector_size(16)));
|
||||
typedef uint8_t uint8x16_t __attribute__((vector_size(16)));
|
||||
typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
|
||||
typedef uint32_t uint32x4_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef float float32x4_t __attribute__((vector_size(16)));
|
||||
typedef double double64x2_t __attribute__((vector_size(16)));
|
||||
|
||||
typedef signed long long long64x2_t __attribute__((vector_size(16)));
|
||||
typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));
|
||||
|
||||
#define ZDNN_CHECK(stmt) \
|
||||
do { \
|
||||
zdnn_status status = (stmt); \
|
||||
GGML_ASSERT(status == ZDNN_OK); \
|
||||
} while (0);
|
||||
|
||||
#endif // GGML_ZDNN_IMPL
|
||||
622
ggml/src/ggml-zdnn/ggml-zdnn.cpp
Normal file
622
ggml/src/ggml-zdnn/ggml-zdnn.cpp
Normal file
@@ -0,0 +1,622 @@
|
||||
#include "zdnn.h"
|
||||
#include "ggml-zdnn.h"
|
||||
#include "ggml-zdnn-impl.h"
|
||||
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#include <csignal>
|
||||
#include <unistd.h>
|
||||
|
||||
struct zdnn_extra {
|
||||
zdnn_tensor_desc pre_tfm_desc;
|
||||
zdnn_tensor_desc tfm_desc;
|
||||
zdnn_ztensor ztensor;
|
||||
|
||||
struct zdnn_extra * extra; // for bias, etc.
|
||||
};
|
||||
|
||||
struct ggml_backend_zdnn_context {
|
||||
int n_threads = GGML_DEFAULT_N_THREADS;
|
||||
};
|
||||
|
||||
inline zdnn_data_types ggml_zdnn_type_mapping(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
return FP32;
|
||||
case GGML_TYPE_F16:
|
||||
return FP16;
|
||||
case GGML_TYPE_BF16:
|
||||
return BFLOAT;
|
||||
case GGML_TYPE_I8:
|
||||
return INT8;
|
||||
case GGML_TYPE_I32:
|
||||
return INT32;
|
||||
case GGML_TYPE_Q8_0:
|
||||
return INT8;
|
||||
default:
|
||||
GGML_ABORT("%s: fatal: unable to determine zTensor data type",
|
||||
__func__);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
inline void ggml_zdnn_create_tensor(zdnn_tensor_desc & pre_tfm_desc,
|
||||
zdnn_tensor_desc & tfm_desc,
|
||||
zdnn_ztensor & ztensor,
|
||||
const ggml_tensor * src,
|
||||
const int64_t * ne,
|
||||
const zdnn_data_layouts layout) {
|
||||
zdnn_init_pre_transformed_desc(
|
||||
layout,
|
||||
ggml_zdnn_type_mapping(src->type),
|
||||
&pre_tfm_desc,
|
||||
ne[3], ne[2], ne[1], ne[0]
|
||||
);
|
||||
|
||||
ZDNN_CHECK(zdnn_generate_transformed_desc(&pre_tfm_desc, &tfm_desc));
|
||||
ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&pre_tfm_desc, &tfm_desc, &ztensor));
|
||||
}
|
||||
|
||||
inline void ggml_zdnn_load_tensor(zdnn_ztensor & ztensor,
|
||||
void * buffer) {
|
||||
ZDNN_CHECK(zdnn_transform_ztensor(&ztensor, buffer));
|
||||
}
|
||||
|
||||
static void ggml_backend_zdnn_mul_mat(ggml_backend_zdnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
|
||||
GGML_ASSERT(ne0 == ne01);
|
||||
GGML_ASSERT(ne1 == ne11);
|
||||
GGML_ASSERT(ne2 == ne12);
|
||||
GGML_ASSERT(ne3 == ne13);
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == ggml_type_size(type));
|
||||
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
const ggml_tensor * weights = src0;
|
||||
const ggml_tensor * inputs = src1;
|
||||
ggml_tensor * output = dst;
|
||||
|
||||
zdnn_tensor_desc pre_tfm_desc_weights, tfm_desc_weights;
|
||||
zdnn_tensor_desc pre_tfm_desc_inputs, tfm_desc_inputs;
|
||||
zdnn_tensor_desc pre_tfm_desc_bias, tfm_desc_bias;
|
||||
zdnn_tensor_desc pre_tfm_desc_output, tfm_desc_output;
|
||||
|
||||
zdnn_ztensor ztensor_weights, ztensor_inputs, ztensor_bias, ztensor_output;
|
||||
|
||||
const int64_t weights_rows = ne01;
|
||||
const int64_t weights_cols = ne00;
|
||||
const int64_t inputs_rows = ne11;
|
||||
const int64_t inputs_cols = ne10;
|
||||
|
||||
assert(inputs_cols == weights_cols);
|
||||
|
||||
const int64_t output_rows = dst->ne[1];
|
||||
const int64_t output_cols = dst->ne[0];
|
||||
|
||||
const int64_t inputs_dim [GGML_MAX_DIMS] = { 1, 1, inputs_cols, inputs_rows };
|
||||
const int64_t weights_dim[GGML_MAX_DIMS] = { 1, 1, weights_cols, weights_rows };
|
||||
const int64_t bias_dim [GGML_MAX_DIMS] = { 1, 1, 1, output_cols };
|
||||
const int64_t output_dim [GGML_MAX_DIMS] = { 1, 1, output_cols, output_rows };
|
||||
|
||||
ggml_zdnn_create_tensor(pre_tfm_desc_inputs, tfm_desc_inputs, ztensor_inputs, src1, inputs_dim, ZDNN_2D);
|
||||
ggml_zdnn_create_tensor(pre_tfm_desc_weights, tfm_desc_weights, ztensor_weights, src0, weights_dim, ZDNN_2D);
|
||||
ggml_zdnn_create_tensor(pre_tfm_desc_bias, tfm_desc_bias, ztensor_bias, dst, bias_dim, ZDNN_1D);
|
||||
ggml_zdnn_create_tensor(pre_tfm_desc_output, tfm_desc_output, ztensor_output, dst, output_dim, ZDNN_2D);
|
||||
|
||||
const size_t weights_size = ggml_element_size(src0);
|
||||
|
||||
void * bias_data = (void *)calloc(output_cols, sizeof(ggml_element_size(dst)));
|
||||
|
||||
ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_weights, weights->data));
|
||||
ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_inputs, inputs->data));
|
||||
ZDNN_CHECK(zdnn_transform_ztensor(&ztensor_bias, bias_data));
|
||||
|
||||
ZDNN_CHECK(zdnn_matmul_transpose_op(&ztensor_inputs, &ztensor_weights, &ztensor_bias,
|
||||
false, true, MATMUL_OP_ADDITION, &ztensor_output));
|
||||
ZDNN_CHECK(zdnn_transform_origtensor(&ztensor_output, output->data));
|
||||
|
||||
ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_weights));
|
||||
ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_inputs));
|
||||
ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_bias));
|
||||
ZDNN_CHECK(zdnn_free_ztensor_buffer(&ztensor_output));
|
||||
|
||||
free(bias_data);
|
||||
}
|
||||
|
||||
static void ggml_backend_zdnn_mul_mat_dispatch(ggml_backend_zdnn_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_UNUSED(ctx);
|
||||
|
||||
bool use_mul_mat_vec =
|
||||
(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F16)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
|
||||
&& src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
|
||||
|
||||
bool use_mul_mat_vec_q =
|
||||
ggml_is_quantized(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
|
||||
bool use_mul_mat_q =
|
||||
ggml_is_quantized(src0->type)
|
||||
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
|
||||
|
||||
// debug helpers
|
||||
// GGML_LOG_INFO("%s: use_mul_mat_vec = %d\n", __func__, use_mul_mat_vec);
|
||||
// GGML_LOG_INFO("%s: use_mul_mat_vec_q = %d\n", __func__, use_mul_mat_vec_q);
|
||||
// GGML_LOG_INFO("%s: use_mul_mat_q = %d\n", __func__, use_mul_mat_q);
|
||||
// GGML_LOG_INFO("%s: src0: %8d %8d %8d %8d\n", __func__, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
|
||||
// GGML_LOG_INFO("%s: %8d %8d %8d %8d\n", __func__, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
|
||||
// GGML_LOG_INFO("%s: src1: %8d %8d %8d %8d\n", __func__, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
|
||||
// GGML_LOG_INFO("%s: %8d %8d %8d %8d\n", __func__, src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
|
||||
// GGML_LOG_INFO("%s: src0 is contiguous %d, transposed %d, type = %s, name = %s\n", __func__, ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
||||
// GGML_LOG_INFO("%s: src1 is contiguous %d, transposed %d, type = %s, name = %s\n", __func__, ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
||||
|
||||
if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16
|
||||
&& !ggml_is_transposed(src0) && !ggml_is_transposed(src1)
|
||||
&& src1->ne[2] * src1->ne[3] > 1) {
|
||||
// general KQ + KQV multi-batch
|
||||
GGML_LOG_INFO("%s: using zdnn_mul_mat_batched for KQ + KQV multi-batch\n", __func__);
|
||||
// ggml_zdnn_mul_mat_batched(ctx, src0, src1, dst);
|
||||
} else if (use_mul_mat_vec) {
|
||||
GGML_LOG_INFO("%s: using zdnn_op_mul_mat_vec for vector multiplication\n", __func__);
|
||||
// ggml_zdnn_op_mul_mat(ctx, src0, src1, dst, ggml_zdnn_op_mul_mat_vec, nullptr);
|
||||
} else if (use_mul_mat_vec_q) {
|
||||
GGML_LOG_INFO("%s: using zdnn_op_mul_mat_vec_q for quantized vector multiplication\n", __func__);
|
||||
// ggml_zdnn_op_mul_mat(ctx, src0, src1, dst, ggml_zdnn_op_mul_mat_vec_q, ggml_zdnn_quantize_row_q8_1);
|
||||
} else if (use_mul_mat_q) {
|
||||
GGML_LOG_INFO("%s: using zdnn_op_mul_mat_q for quantized matrix multiplication\n", __func__);
|
||||
// ggml_zdnn_op_mul_mat(ctx, src0, src1, dst, ggml_zdnn_op_mul_mat_q, ggml_zdnn_quantize_mmq_q8_1);
|
||||
} else {
|
||||
// GGML_LOG_INFO("%s: using zdnn_op_mul_mat for general matrix multiplication\n", __func__);
|
||||
ggml_backend_zdnn_mul_mat(ctx, src0, src1, dst);
|
||||
}
|
||||
}
|
||||
|
||||
static bool ggml_backend_zdnn_compute_forward(ggml_backend_zdnn_context * ctx, ggml_tensor * dst) {
|
||||
switch (dst->op) {
|
||||
case GGML_OP_MUL_MAT:
|
||||
ggml_backend_zdnn_mul_mat_dispatch(ctx, dst->src[0], dst->src[1], dst);
|
||||
break;
|
||||
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static const char * ggml_backend_zdnn_get_name(ggml_backend_t backend) {
|
||||
return GGML_ZDNN_NAME;
|
||||
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
static void ggml_backend_zdnn_free(ggml_backend_t backend) {
|
||||
ggml_backend_zdnn_context * ctx = (ggml_backend_zdnn_context *)backend->context;
|
||||
delete ctx;
|
||||
delete backend;
|
||||
}
|
||||
|
||||
static ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
||||
ggml_backend_zdnn_context * ctx = (ggml_backend_zdnn_context *)backend->context;
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
if (ggml_is_empty(node)
|
||||
|| node->op == GGML_OP_NONE
|
||||
|| node->op == GGML_OP_RESHAPE
|
||||
|| node->op == GGML_OP_VIEW
|
||||
|| node->op == GGML_OP_PERMUTE
|
||||
|| node->op == GGML_OP_TRANSPOSE) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool ok = ggml_backend_zdnn_compute_forward(ctx, node);
|
||||
if (!ok) {
|
||||
GGML_LOG_ERROR("%s: unsupported op %s (%s)\n",
|
||||
__func__, node->name, ggml_op_name(node->op));
|
||||
}
|
||||
|
||||
GGML_ASSERT(ok);
|
||||
}
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
|
||||
GGML_UNUSED(backend);
|
||||
}
|
||||
|
||||
static ggml_backend_i ggml_backend_zdnn_i = {
|
||||
/* .get_name = */ ggml_backend_zdnn_get_name,
|
||||
/* .free = */ ggml_backend_zdnn_free,
|
||||
/* .set_tensor_async = */ NULL,
|
||||
/* .get_tensor_async = */ NULL,
|
||||
/* .cpy_tensor_async = */ NULL,
|
||||
/* .synchronize = */ NULL,
|
||||
/* .graph_plan_create = */ NULL,
|
||||
/* .graph_plan_free = */ NULL,
|
||||
/* .graph_plan_update = */ NULL,
|
||||
/* .graph_plan_compute = */ NULL,
|
||||
/* .graph_compute = */ ggml_backend_zdnn_graph_compute,
|
||||
/* .event_record = */ NULL,
|
||||
/* .event_wait = */ NULL,
|
||||
};
|
||||
|
||||
static ggml_guid_t ggml_backend_zdnn_guid(void) {
|
||||
// guid spells out IBM-NNPA-ACCELER
|
||||
static ggml_guid guid = { 0x49, 0x42, 0x4D, 0x2D, 0x4E, 0x4E, 0x50, 0x41,
|
||||
0x2D, 0x41, 0x43, 0x43, 0x45, 0x4C, 0x45, 0x52 };
|
||||
|
||||
return &guid;
|
||||
}
|
||||
|
||||
ggml_backend_t ggml_backend_zdnn_init(void) {
|
||||
ggml_backend_zdnn_context * ctx = new ggml_backend_zdnn_context;
|
||||
|
||||
ggml_backend_t backend = new ggml_backend {
|
||||
/* .guid = */ ggml_backend_zdnn_guid(),
|
||||
/* .iface = */ ggml_backend_zdnn_i,
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_zdnn_reg(), 0),
|
||||
/* .context = */ ctx,
|
||||
};
|
||||
|
||||
return backend;
|
||||
}
|
||||
|
||||
bool ggml_backend_is_zdnn(ggml_backend_t backend) {
|
||||
return backend != NULL &&
|
||||
ggml_guid_matches(backend->guid, ggml_backend_zdnn_guid());
|
||||
}
|
||||
|
||||
void ggml_backend_zdnn_set_n_threads(ggml_backend_t backend_zdnn, int n_threads) {
|
||||
GGML_ASSERT(ggml_backend_is_zdnn(backend_zdnn));
|
||||
|
||||
ggml_backend_zdnn_context * ctx = (ggml_backend_zdnn_context *)backend_zdnn->context;
|
||||
ctx->n_threads = n_threads;
|
||||
}
|
||||
|
||||
static const char * ggml_backend_zdnn_device_get_name(ggml_backend_dev_t dev) {
|
||||
return GGML_ZDNN_NAME;
|
||||
}
|
||||
|
||||
static const char * ggml_backend_zdnn_device_get_description(ggml_backend_dev_t dev) {
|
||||
return GGML_ZDNN_NAME;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static void ggml_backend_zdnn_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
*free = 0;
|
||||
*total = 0;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_zdnn_device_get_type(ggml_backend_dev_t dev) {
|
||||
return GGML_BACKEND_DEVICE_TYPE_ACCEL;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static void ggml_backend_zdnn_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_zdnn_device_get_name(dev);
|
||||
props->description = ggml_backend_zdnn_device_get_description(dev);
|
||||
props->type = ggml_backend_zdnn_device_get_type(dev);
|
||||
ggml_backend_zdnn_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
props->caps = {
|
||||
/* .async = */ false,
|
||||
/* .host_buffer = */ false,
|
||||
/* .buffer_from_host_ptr = */ true,
|
||||
/* .events = */ false,
|
||||
};
|
||||
}
|
||||
|
||||
static ggml_backend_t ggml_backend_zdnn_device_init_backend(ggml_backend_dev_t dev, const char * params) {
|
||||
return ggml_backend_zdnn_init();
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
GGML_UNUSED(params);
|
||||
}
|
||||
|
||||
static void * ggml_backend_zdnn_buffer_get_base(ggml_backend_buffer_t buffer) {
|
||||
uintptr_t data = (uintptr_t)buffer->context;
|
||||
if (data % 256 != 0) {
|
||||
data = GGML_PAD(data, 256);
|
||||
}
|
||||
|
||||
return (void *)data;
|
||||
}
|
||||
|
||||
static ggml_status ggml_backend_zdnn_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
|
||||
if (tensor->view_src != NULL) {
|
||||
assert(tensor->view_src->buffer->buft == buffer->buft);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
zdnn_extra * extra = (zdnn_extra *)malloc(sizeof(zdnn_extra));
|
||||
const int64_t dims[GGML_MAX_DIMS] = { 1, 1, tensor->ne[0], tensor->ne[1] };
|
||||
|
||||
zdnn_init_pre_transformed_desc(
|
||||
ZDNN_2D,
|
||||
ggml_zdnn_type_mapping(tensor->type),
|
||||
&extra->pre_tfm_desc,
|
||||
dims[3], dims[2], dims[1], dims[0]
|
||||
);
|
||||
|
||||
ZDNN_CHECK(zdnn_generate_transformed_desc(&extra->pre_tfm_desc, &extra->tfm_desc));
|
||||
ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&extra->pre_tfm_desc, &extra->tfm_desc, &extra->ztensor));
|
||||
|
||||
if (tensor->op == GGML_OP_MUL_MAT) {
|
||||
zdnn_extra * bias_extra = (zdnn_extra *)malloc(sizeof(zdnn_extra));
|
||||
const int64_t bias_dims[GGML_MAX_DIMS] = { 1, 1, 1, tensor->ne[0] };
|
||||
|
||||
zdnn_init_pre_transformed_desc(
|
||||
ZDNN_1D,
|
||||
ggml_zdnn_type_mapping(tensor->type),
|
||||
&bias_extra->pre_tfm_desc,
|
||||
bias_dims[3], bias_dims[2], bias_dims[1], bias_dims[0]
|
||||
);
|
||||
ZDNN_CHECK(zdnn_generate_transformed_desc(&bias_extra->pre_tfm_desc, &bias_extra->tfm_desc));
|
||||
ZDNN_CHECK(zdnn_init_ztensor_with_malloc(&bias_extra->pre_tfm_desc, &bias_extra->tfm_desc, &bias_extra->ztensor));
|
||||
|
||||
extra->extra = bias_extra;
|
||||
}
|
||||
|
||||
tensor->extra = extra;
|
||||
|
||||
return GGML_STATUS_SUCCESS;
|
||||
}
|
||||
|
||||
static void ggml_backend_zdnn_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
||||
ggml_aligned_free(buffer->context, buffer->size);
|
||||
}
|
||||
|
||||
static void ggml_backend_zdnn_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
||||
memset((char *)tensor->data + offset, value, size);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_zdnn_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
||||
memcpy((char *)tensor->data + offset, data, size);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_zdnn_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
||||
memcpy(data, (const char *)tensor->data + offset, size);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_zdnn_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
||||
memset(buffer->context, value, buffer->size);
|
||||
}
|
||||
|
||||
static const ggml_backend_buffer_i ggml_backend_zdnn_buffer_i = {
|
||||
/* .free_buffer = */ ggml_backend_zdnn_buffer_free_buffer, // zdnn buffers are not owned by the backend
|
||||
/* .get_base = */ ggml_backend_zdnn_buffer_get_base,
|
||||
/* .init_tensor = */ ggml_backend_zdnn_buffer_init_tensor,
|
||||
/* .memset_tensor = */ ggml_backend_zdnn_buffer_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_zdnn_buffer_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_zdnn_buffer_get_tensor,
|
||||
/* .cpy_tensor = */ NULL,
|
||||
/* .clear = */ ggml_backend_zdnn_buffer_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
static const ggml_backend_buffer_i ggml_backend_zdnn_buffer_from_ptr_i = {
|
||||
/* .free_buffer = */ NULL, // ptr is not owned by the buffer
|
||||
/* .get_base = */ ggml_backend_zdnn_buffer_get_base,
|
||||
/* .init_tensor = */ ggml_backend_zdnn_buffer_init_tensor,
|
||||
/* .memset_tensor = */ ggml_backend_zdnn_buffer_memset_tensor,
|
||||
/* .set_tensor = */ ggml_backend_zdnn_buffer_set_tensor,
|
||||
/* .get_tensor = */ ggml_backend_zdnn_buffer_get_tensor,
|
||||
/* .cpy_tensor = */ NULL,
|
||||
/* .clear = */ ggml_backend_zdnn_buffer_clear,
|
||||
/* .reset = */ NULL,
|
||||
};
|
||||
|
||||
static const char * ggml_backend_zdnn_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return GGML_ZDNN_NAME;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_zdnn_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
||||
void * data = ggml_aligned_malloc(size);
|
||||
if (data == NULL) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate %zu bytes\n", __func__, size);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return ggml_backend_buffer_init(buft, ggml_backend_zdnn_buffer_i, data, size);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_zdnn_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
||||
return 256;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static bool ggml_backend_zdnn_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
||||
return true;
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_zdnn_device_get_buffer_type(ggml_backend_dev_t dev) {
|
||||
static ggml_backend_buffer_type ggml_backend_zdnn_buffer_type = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_zdnn_buffer_type_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_zdnn_buffer_type_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_zdnn_buffer_type_get_alignment,
|
||||
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
||||
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
||||
/* .is_host = */ ggml_backend_zdnn_buffer_type_is_host,
|
||||
},
|
||||
/* .device = */ NULL,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_zdnn_buffer_type;
|
||||
}
|
||||
|
||||
static const char * ggml_backend_zdnn_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
|
||||
return GGML_ZDNN_NAME "_Mapped";
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_type_t ggml_backend_zdnn_buffer_from_ptr_type(void) {
|
||||
static ggml_backend_buffer_type ggml_backend_zdnn_buffer_type = {
|
||||
/* .iface = */ {
|
||||
/* .get_name = */ ggml_backend_zdnn_buffer_from_ptr_type_get_name,
|
||||
/* .alloc_buffer = */ ggml_backend_zdnn_buffer_type_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_zdnn_buffer_type_get_alignment,
|
||||
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
||||
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
|
||||
/* .is_host = */ ggml_backend_zdnn_buffer_type_is_host,
|
||||
},
|
||||
/* .device = */ NULL,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_zdnn_buffer_type;
|
||||
}
|
||||
|
||||
static ggml_backend_buffer_t ggml_backend_zdnn_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
|
||||
GGML_ASSERT((uintptr_t)ptr % 256 == 0 && "buffer pointer must be aligned");
|
||||
return ggml_backend_buffer_init(ggml_backend_zdnn_buffer_from_ptr_type(), ggml_backend_zdnn_buffer_from_ptr_i, ptr, size);
|
||||
}
|
||||
|
||||
static bool ggml_backend_zdnn_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
const ggml_tensor * src1 = op->src[1];
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_NONE:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
return true;
|
||||
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
const ggml_tensor * src0 = op->src[0];
|
||||
const ggml_tensor * src1 = op->src[1];
|
||||
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
|
||||
const int64_t ne0 = op->ne[0];
|
||||
const int64_t ne1 = op->ne[1];
|
||||
|
||||
const int64_t max_batch = zdnn_get_nnpa_max_dim_idx_size();
|
||||
|
||||
return ggml_is_contiguous(src0) &&
|
||||
ggml_is_contiguous(src1) &&
|
||||
src1->type == GGML_TYPE_F32 &&
|
||||
(ne0 <= max_batch && ne1 <= max_batch && ne10 <= max_batch) &&
|
||||
(src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
|
||||
}
|
||||
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static bool ggml_backend_zdnn_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
|
||||
return buft->iface.get_name == ggml_backend_zdnn_buffer_type_get_name;
|
||||
|
||||
GGML_UNUSED(dev);
|
||||
}
|
||||
|
||||
static ggml_backend_device_i ggml_backend_zdnn_device_i = {
|
||||
/* .get_name = */ ggml_backend_zdnn_device_get_name,
|
||||
/* .get_description = */ ggml_backend_zdnn_device_get_description,
|
||||
/* .get_memory = */ ggml_backend_zdnn_device_get_memory,
|
||||
/* .get_type = */ ggml_backend_zdnn_device_get_type,
|
||||
/* .get_props = */ ggml_backend_zdnn_device_get_props,
|
||||
/* .init_backend = */ ggml_backend_zdnn_device_init_backend,
|
||||
/* .get_buffer_type = */ ggml_backend_zdnn_device_get_buffer_type,
|
||||
/* .get_host_buffer_type = */ NULL,
|
||||
/* .buffer_from_host_ptr = */ ggml_backend_zdnn_device_buffer_from_host_ptr,
|
||||
/* .supports_op = */ ggml_backend_zdnn_device_supports_op,
|
||||
/* .supports_buft = */ ggml_backend_zdnn_device_supports_buft,
|
||||
/* .offload_op = */ NULL,
|
||||
/* .event_new = */ NULL,
|
||||
/* .event_free = */ NULL,
|
||||
/* .event_synchronize = */ NULL,
|
||||
};
|
||||
|
||||
//
|
||||
// backend registry
|
||||
//
|
||||
|
||||
static const char * ggml_backend_zdnn_reg_get_name(ggml_backend_reg_t reg) {
|
||||
return GGML_ZDNN_NAME;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_zdnn_reg_get_device_count(ggml_backend_reg_t reg) {
|
||||
return 1;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static ggml_backend_dev_t ggml_backend_zdnn_reg_get_device(ggml_backend_reg_t reg, size_t index) {
|
||||
GGML_ASSERT(index == 0);
|
||||
|
||||
static ggml_backend_device ggml_backend_zdnn_device = {
|
||||
/* .iface = */ ggml_backend_zdnn_device_i,
|
||||
/* .reg = */ reg,
|
||||
/* .context = */ nullptr,
|
||||
};
|
||||
|
||||
return &ggml_backend_zdnn_device;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
GGML_UNUSED(index);
|
||||
}
|
||||
|
||||
static void * ggml_backend_zdnn_get_proc_address(ggml_backend_reg_t reg, const char * name) {
|
||||
if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
|
||||
return (void *)ggml_backend_zdnn_set_n_threads;
|
||||
}
|
||||
return NULL;
|
||||
|
||||
GGML_UNUSED(reg);
|
||||
}
|
||||
|
||||
static const ggml_backend_reg_i ggml_backend_zdnn_reg_i = {
|
||||
/* .get_name = */ ggml_backend_zdnn_reg_get_name,
|
||||
/* .get_device_count = */ ggml_backend_zdnn_reg_get_device_count,
|
||||
/* .get_device = */ ggml_backend_zdnn_reg_get_device,
|
||||
/* .get_proc_address = */ ggml_backend_zdnn_get_proc_address,
|
||||
};
|
||||
|
||||
ggml_backend_reg_t ggml_backend_zdnn_reg(void) {
|
||||
static ggml_backend_reg ggml_backend_zdnn_reg = {
|
||||
/* .api_version = */ GGML_ZDNN_VERSION,
|
||||
/* .iface = */ ggml_backend_zdnn_reg_i,
|
||||
/* .context = */ NULL,
|
||||
};
|
||||
|
||||
return &ggml_backend_zdnn_reg;
|
||||
}
|
||||
|
||||
GGML_BACKEND_DL_IMPL(ggml_backend_zdnn_reg)
|
||||
660
ggml/src/ggml-zdnn/zdnn.h
Normal file
660
ggml/src/ggml-zdnn/zdnn.h
Normal file
@@ -0,0 +1,660 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
/*
|
||||
* Copyright IBM Corp. 2021, 2024
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef ZDNN_ZDNN_H_
|
||||
#define ZDNN_ZDNN_H_
|
||||
|
||||
#include <inttypes.h>
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// NOTE:
|
||||
// Ensure that symbols in zdnn.h and zdnn.map are in sync!
|
||||
// Please also have a look at zdnn.map how to add, update or remove a symbol.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Initializer and global variables
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
void zdnn_init();
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// zDNN Status
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// NOTE:
|
||||
// Update status.c and zdnn_private.h after any status modification!
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Status categories
|
||||
#define ZDNN_WARNING 0x00020000
|
||||
#define ZDNN_PARAMETER_ERROR 0x00040000
|
||||
#define ZDNN_DATA_ERROR 0x00100000
|
||||
#define ZDNN_HW_ERROR 0x000c0000
|
||||
|
||||
// clang-format off
|
||||
typedef enum zdnn_status {
|
||||
// ----------------------------------------------------------------
|
||||
ZDNN_OK = 0x00000000, // Success.
|
||||
// ----------------------------------------------------------------
|
||||
ZDNN_ELEMENT_RANGE_VIOLATION = ZDNN_WARNING + 0x0001, // zAIU operation resulted in data that was out of the normal range.
|
||||
// ----------------------------------------------------------------
|
||||
ZDNN_INVALID_SHAPE = ZDNN_PARAMETER_ERROR + 0x0001, // Invalid shape information in one (or more) of the input/output tensor(s).
|
||||
ZDNN_INVALID_LAYOUT, // Invalid layout information in one (or more) of the input/output tensor(s).
|
||||
ZDNN_INVALID_TYPE, // Invalid type information in one (or more) of the input/output tensor(s).
|
||||
ZDNN_INVALID_FORMAT, // Invalid format information in one (or more) of the input/output tensor(s).
|
||||
ZDNN_INVALID_DIRECTION, // Invalid RNN direction.
|
||||
ZDNN_INVALID_CONCAT_INFO, // Invalid concatenation info.
|
||||
ZDNN_INVALID_STRIDE_PADDING, // Invalid padding type parameter for current strides
|
||||
ZDNN_INVALID_STRIDES, // Invalid stride height or width parameter.
|
||||
ZDNN_MISALIGNED_PARMBLOCK, // NNPA parameter block is not on double word boundary.
|
||||
ZDNN_INVALID_CLIPPING_VALUE, // Invalid clipping for the specified operation.
|
||||
ZDNN_INVALID_ADJUSTMENT_FACTOR, // Invalid adjustment for the specified operation.
|
||||
ZDNN_INVALID_EPSILON, // Invalid epsilon for the specified operation.
|
||||
ZDNN_INVALID_TRANSFORM_TYPE, // Invalid transformation type
|
||||
ZDNN_INVALID_BETA, // Invalid beta value for the specified operation.
|
||||
ZDNN_INVALID_GAMMA, // Invalid gamma value for the specified operation.
|
||||
ZDNN_INVALID_BESSEL_CORRECTION, // Invalid bessel correction value for the specified operation.
|
||||
ZDNN_INVALID_SCALE, // Invalid scale value for the specified operation.
|
||||
ZDNN_INVALID_OFFSET, // Invalid offset value for the specified operation.
|
||||
// ----------------------------------------------------------------
|
||||
ZDNN_ALLOCATION_FAILURE = ZDNN_DATA_ERROR + 0x0001, // Can not allocate storage.
|
||||
ZDNN_INVALID_BUFFER, // Buffer address is NULL or not on 4K-byte boundary, or insufficient buffer size.
|
||||
ZDNN_CONVERT_FAILURE, // Floating point data conversion failure.
|
||||
ZDNN_INVALID_STATE, // Invalid zTensor state.
|
||||
ZDNN_UNSUPPORTED_AIU_EXCEPTION, // zAIU operation returned an unexpected exception.
|
||||
// ----------------------------------------------------------------
|
||||
ZDNN_UNSUPPORTED_PARMBLOCK = ZDNN_HW_ERROR + 0x0001, // NNPA parameter block format is not supported by the model.
|
||||
ZDNN_UNAVAILABLE_FUNCTION, // Specified NNPA function is not defined or installed on the machine.
|
||||
ZDNN_UNSUPPORTED_FORMAT = ZDNN_HW_ERROR + 0x0010, // Specified tensor data layout format is not supported.
|
||||
ZDNN_UNSUPPORTED_TYPE, // Specified tensor data type is not supported.
|
||||
ZDNN_EXCEEDS_MDIS, // Tensor dimension exceeds maximum dimension index size (MDIS).
|
||||
ZDNN_EXCEEDS_MTS, // Total number of elements in tensor exceeds maximum tensor size. (MTS).
|
||||
ZDNN_MISALIGNED_TENSOR, // Tensor address is not on 4K-byte boundary.
|
||||
ZDNN_MISALIGNED_SAVEAREA, // Function specific save area address is not on 4K-byte boundary.
|
||||
// ----------------------------------------------------------------
|
||||
// Function specific response code (F00x)
|
||||
ZDNN_FUNC_RC_F000 = ZDNN_HW_ERROR + 0xF000, // Function specific response code (F000).
|
||||
ZDNN_FUNC_RC_F001, // Function specific response code (F001).
|
||||
ZDNN_FUNC_RC_F002, // Function specific response code (F002).
|
||||
ZDNN_FUNC_RC_F003, // Function specific response code (F003).
|
||||
ZDNN_FUNC_RC_F004, // Function specific response code (F004).
|
||||
ZDNN_FUNC_RC_F005, // Function specific response code (F005).
|
||||
ZDNN_FUNC_RC_F006, // Function specific response code (F006).
|
||||
ZDNN_FUNC_RC_F007, // Function specific response code (F007).
|
||||
ZDNN_FUNC_RC_F008, // Function specific response code (F008).
|
||||
ZDNN_FUNC_RC_F009, // Function specific response code (F009).
|
||||
|
||||
// ----------------------------------------------------------------
|
||||
} zdnn_status;
|
||||
// clang-format on
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// NNPA hardware defined values as described in
|
||||
// z/Architecture - Principles of Operation
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
typedef enum nnpa_function_code {
|
||||
NNPA_QAF = 0,
|
||||
NNPA_ADD = 16,
|
||||
NNPA_SUB = 17,
|
||||
NNPA_MUL = 18,
|
||||
NNPA_DIV = 19,
|
||||
NNPA_MIN = 20,
|
||||
NNPA_MAX = 21,
|
||||
NNPA_LOG = 32,
|
||||
NNPA_EXP = 33,
|
||||
NNPA_SQRT = 34,
|
||||
NNPA_INVSQRT = 35,
|
||||
// reserved = 48
|
||||
NNPA_RELU = 49,
|
||||
NNPA_TANH = 50,
|
||||
NNPA_SIGMOID = 51,
|
||||
NNPA_SOFTMAX = 52,
|
||||
NNPA_GELU = 53,
|
||||
NNPA_BATCHNORMALIZATION = 64,
|
||||
NNPA_MOMENTS = 65,
|
||||
NNPA_LAYERNORM = 66,
|
||||
NNPA_NORM = 67,
|
||||
NNPA_MAXPOOL2D = 80,
|
||||
NNPA_AVGPOOL2D = 81,
|
||||
NNPA_LSTMACT = 96,
|
||||
NNPA_GRUACT = 97,
|
||||
NNPA_CONVOLUTION = 112,
|
||||
NNPA_MATMUL_OP = 113,
|
||||
NNPA_MATMUL_OP_BCAST23 = 114,
|
||||
NNPA_MATMUL_OP_BCAST1 = 115,
|
||||
NNPA_TRANSFORM = 240,
|
||||
NNPA_REDUCE = 241
|
||||
} nnpa_function_code;
|
||||
|
||||
typedef enum nnpa_parmblk_format {
|
||||
NNPA_PARMBLKFORMAT_0 = 0,
|
||||
NNPA_PARMBLKFORMAT_1 = 1,
|
||||
} nnpa_parmblk_format;
|
||||
|
||||
typedef enum nnpa_data_type {
|
||||
NNPA_DATATYPE_1 = 0,
|
||||
NNPA_32_BIT_BINARY_FP_SHORT = 6,
|
||||
NNPA_8_BIT_BINARY_INT = 8,
|
||||
NNPA_32_BIT_BINARY_INT = 10
|
||||
} nnpa_data_type;
|
||||
|
||||
typedef enum nnpa_layout_format {
|
||||
NNPA_LAYOUTFMT_4DFEATURE = 0,
|
||||
NNPA_LAYOUTFMT_4DKERNEL = 1,
|
||||
NNPA_LAYOUTFMT_4DWEIGHTS = 2,
|
||||
NNPA_LAYOUTFMT_4DGENERIC = 31
|
||||
} nnpa_layout_format;
|
||||
|
||||
typedef enum nnpa_bfp_format {
|
||||
// 0 is reversed
|
||||
NNPA_BFPFMT_TINY = 1,
|
||||
NNPA_BFPFMT_SHORT = 2
|
||||
} nnpa_bfp_format;
|
||||
|
||||
// NNPA_SOFTMAX, NNPA_REDUCE, and NNPA_TRANSFORM require 8K work area
|
||||
#define ZDNN_SOFTMAX_SAVEAREA_SIZE 8 * 1024
|
||||
#define ZDNN_8K_SAVEAREA_SIZE 8 * 1024
|
||||
|
||||
// NNPA Hardware defined values for Function Specific Parameters
|
||||
typedef enum nnpa_matmul_operations {
|
||||
NNPA_MATMUL_OP_ADDITION = 0,
|
||||
NNPA_MATMUL_OP_COMP_HIGH = 1,
|
||||
NNPA_MATMUL_OP_COMP_NOT_LOW = 2,
|
||||
NNPA_MATMUL_OP_COMP_EQUAL = 3,
|
||||
NNPA_MATMUL_OP_COMP_NOT_EQUAL = 4,
|
||||
NNPA_MATMUL_OP_COMP_NOT_HIGH = 5,
|
||||
NNPA_MATMUL_OP_COMP_LOW = 6,
|
||||
} nnpa_matmul_operations;
|
||||
|
||||
typedef enum nnpa_matmul_bcast_operations {
|
||||
NNPA_MATMUL_BCAST_OP_ADDITION = 0,
|
||||
NNPA_MATMUL_BCAST_OP_COMP_HIGH = 1,
|
||||
NNPA_MATMUL_BCAST_OP_COMP_NOT_LOW = 2,
|
||||
NNPA_MATMUL_BCAST_OP_COMP_EQUAL = 3,
|
||||
NNPA_MATMUL_BCAST_OP_COMP_NOT_EQUAL = 4,
|
||||
NNPA_MATMUL_BCAST_OP_COMP_NOT_HIGH = 5,
|
||||
NNPA_MATMUL_BCAST_OP_COMP_LOW = 6
|
||||
} nnpa_matmul_bcast_operations;
|
||||
|
||||
typedef enum nnpa_softmax_act {
|
||||
NNPA_SOFTMAX_NONE = 0,
|
||||
NNPA_SOFTMAX_LOG = 1
|
||||
} nnpa_softmax_act;
|
||||
|
||||
typedef enum nnpa_reduce_operations {
|
||||
NNPA_REDUCE_OP_MINIMUM = 0,
|
||||
NNPA_REDUCE_OP_MINIMUM_IDX = 1,
|
||||
NNPA_REDUCE_OP_MAXIMUM = 2,
|
||||
NNPA_REDUCE_OP_MAXIMUM_IDX = 3
|
||||
} nnpa_reduce_operations;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// zdnn_query_*() bit-field enums
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// pos is counting from left to right
|
||||
#define MSB_BITMASK(field_size, pos) 1u << ((field_size - 1) - pos)
|
||||
|
||||
typedef enum zdnn_query_datatypes {
|
||||
QUERY_DATATYPE_INTERNAL1 = MSB_BITMASK(16, NNPA_DATATYPE_1),
|
||||
QUERY_DATATYPE_BINARY_FP32 = MSB_BITMASK(16, NNPA_32_BIT_BINARY_FP_SHORT),
|
||||
QUERY_DATATYPE_BINARY_INT8 = MSB_BITMASK(16, NNPA_8_BIT_BINARY_INT),
|
||||
QUERY_DATATYPE_BINARY_INT32 = MSB_BITMASK(16, NNPA_32_BIT_BINARY_INT)
|
||||
} zdnn_query_datatypes;
|
||||
|
||||
typedef enum zdnn_query_layoutfmts {
|
||||
QUERY_LAYOUTFMT_4DFEATURE = MSB_BITMASK(32, NNPA_LAYOUTFMT_4DFEATURE),
|
||||
QUERY_LAYOUTFMT_4DKERNEL = MSB_BITMASK(32, NNPA_LAYOUTFMT_4DKERNEL),
|
||||
QUERY_LAYOUTFMT_4DWEIGHTS = MSB_BITMASK(32, NNPA_LAYOUTFMT_4DWEIGHTS),
|
||||
QUERY_LAYOUTFMT_4DGENERIC = MSB_BITMASK(32, NNPA_LAYOUTFMT_4DGENERIC)
|
||||
} zdnn_query_layoutfmts;
|
||||
|
||||
typedef enum zdnn_query_bfpfmts {
|
||||
QUERY_BFPFMT_TINY = MSB_BITMASK(16, NNPA_BFPFMT_TINY),
|
||||
QUERY_BFPFMT_SHORT = MSB_BITMASK(16, NNPA_BFPFMT_SHORT)
|
||||
} zdnn_query_bfpfmts;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// ZDNN enums
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
typedef enum zdnn_data_types {
|
||||
ZDNN_DLFLOAT16 = NNPA_DATATYPE_1, // 16-bit deep learning format
|
||||
ZDNN_BINARY_FP32 =
|
||||
NNPA_32_BIT_BINARY_FP_SHORT, // 32-bit binary-floating-point format
|
||||
ZDNN_BINARY_INT8 =
|
||||
NNPA_8_BIT_BINARY_INT, // 8-bit signed or unsigned binary integer
|
||||
ZDNN_BINARY_INT32 =
|
||||
NNPA_32_BIT_BINARY_INT, // 32-bit signed or unsigned binary integer
|
||||
INT8 = 251, // 8-bit signed or unsigned binary integer format
|
||||
INT32 = 252, // 32-bit signed or unsigned binary integer format
|
||||
BFLOAT = 253, // Brain floating point format
|
||||
FP16 = 254, // 16-bit IEEE-754 floating point format
|
||||
FP32 = 255, // 32-bit IEEE-754 floating point format
|
||||
} zdnn_data_types;
|
||||
|
||||
typedef enum zdnn_data_layouts {
|
||||
ZDNN_1D, // 1d tensor
|
||||
ZDNN_2D, // 2d tensor
|
||||
ZDNN_2DS, // represents special 2D tensors required by LSTM/GRU
|
||||
ZDNN_3D, // 3d tensor
|
||||
ZDNN_3DS, // represents special 3D tensors required by
|
||||
// LSTM/GRU/Softmax/Matmul
|
||||
ZDNN_ZRH, // represents (update, reset, hidden) used by GRU
|
||||
ZDNN_4D, // 4d tensor
|
||||
ZDNN_4DS, // represents special 4D tensors required by LSTM/GRU output
|
||||
ZDNN_NHWC, // 4d feature tensor in NHWC
|
||||
ZDNN_NCHW, // 4d feature tensor in NCHW
|
||||
ZDNN_FICO, // represents (forget, input, cell, output) used by LSTM
|
||||
ZDNN_HWCK, // 4d kernel CNN tensor
|
||||
ZDNN_BIDIR_ZRH, // ZRH variant to work with bidirectional LSTM/GRU output
|
||||
ZDNN_BIDIR_FICO // FICO variant to work with bidirectional LSTM/GRU output
|
||||
} zdnn_data_layouts;
|
||||
|
||||
typedef enum zdnn_data_formats {
|
||||
ZDNN_FORMAT_4DFEATURE =
|
||||
NNPA_LAYOUTFMT_4DFEATURE, // tensor in zAIU data layout format 0
|
||||
ZDNN_FORMAT_4DKERNEL =
|
||||
NNPA_LAYOUTFMT_4DKERNEL, // tensor in zAIU data layout format 1
|
||||
ZDNN_FORMAT_4DWEIGHTS =
|
||||
NNPA_LAYOUTFMT_4DWEIGHTS, // tensor in zAIU data layout format 2
|
||||
ZDNN_FORMAT_4DGENERIC =
|
||||
NNPA_LAYOUTFMT_4DGENERIC, // tensor in zAIU data layout 31
|
||||
} zdnn_data_formats;
|
||||
|
||||
typedef enum zdnn_quantized_transform_types {
|
||||
QUANTIZED_DLFLOAT16 = 0, // quantized dlfloat16
|
||||
QUANTIZED_INT8 = 1, // quantized int8
|
||||
QUANTIZED_WEIGHTS_INT8 = 2 // quantized weights
|
||||
} zdnn_quantized_transform_types;
|
||||
|
||||
// Supported padding types for use in pooling functions
|
||||
typedef enum zdnn_pool_padding {
|
||||
VALID_PADDING = 0,
|
||||
SAME_PADDING = 1
|
||||
} zdnn_pool_padding;
|
||||
|
||||
// Support operations for use in matmul functions
|
||||
typedef enum zdnn_matmul_ops {
|
||||
MATMUL_OP_ADDITION = NNPA_MATMUL_OP_ADDITION,
|
||||
MATMUL_OP_GREATER = NNPA_MATMUL_OP_COMP_HIGH,
|
||||
MATMUL_OP_GREATER_EQUAL = NNPA_MATMUL_OP_COMP_NOT_LOW,
|
||||
MATMUL_OP_EQUAL = NNPA_MATMUL_OP_COMP_EQUAL,
|
||||
MATMUL_OP_NOT_EQUAL = NNPA_MATMUL_OP_COMP_NOT_EQUAL,
|
||||
MATMUL_OP_LESSER_EQUAL = NNPA_MATMUL_OP_COMP_NOT_HIGH,
|
||||
MATMUL_OP_LESSER = NNPA_MATMUL_OP_COMP_LOW
|
||||
} zdnn_matmul_ops;
|
||||
|
||||
// Support operations for use in matmul function
|
||||
typedef enum zdnn_matmul_bcast_ops {
|
||||
MATMUL_BCAST_OP_ADDITION = NNPA_MATMUL_BCAST_OP_ADDITION,
|
||||
MATMUL_BCAST_OP_GREATER = NNPA_MATMUL_BCAST_OP_COMP_HIGH,
|
||||
MATMUL_BCAST_OP_GREATER_EQUAL = NNPA_MATMUL_BCAST_OP_COMP_NOT_LOW,
|
||||
MATMUL_BCAST_OP_EQUAL = NNPA_MATMUL_BCAST_OP_COMP_EQUAL,
|
||||
MATMUL_BCAST_OP_NOT_EQUAL = NNPA_MATMUL_BCAST_OP_COMP_NOT_EQUAL,
|
||||
MATMUL_BCAST_OP_LESSER_EQUAL = NNPA_MATMUL_BCAST_OP_COMP_NOT_HIGH,
|
||||
MATMUL_BCAST_OP_LESSER = NNPA_MATMUL_BCAST_OP_COMP_LOW
|
||||
|
||||
} zdnn_matmul_bcast_ops;
|
||||
|
||||
typedef enum zdnn_softmax_act {
|
||||
SOFTMAX_ACT_NONE = NNPA_SOFTMAX_NONE,
|
||||
SOFTMAX_ACT_LOG = NNPA_SOFTMAX_LOG
|
||||
} zdnn_softmax_act;
|
||||
|
||||
typedef enum zdnn_conv2d_act {
|
||||
CONV2D_ACT_NONE,
|
||||
CONV2D_ACT_RELU
|
||||
} zdnn_conv2d_act;
|
||||
|
||||
// Support operations for use in reduce functions
|
||||
typedef enum zdnn_reduce_ops {
|
||||
REDUCE_OP_MINIMUM = NNPA_REDUCE_OP_MINIMUM,
|
||||
REDUCE_OP_MINIMUM_IDX = NNPA_REDUCE_OP_MINIMUM_IDX,
|
||||
REDUCE_OP_MAXIMUM = NNPA_REDUCE_OP_MAXIMUM,
|
||||
REDUCE_OP_MAXIMUM_IDX = NNPA_REDUCE_OP_MAXIMUM_IDX
|
||||
} zdnn_reduce_ops;
|
||||
|
||||
typedef enum zdnn_moments_bessel {
|
||||
MOMENTS_BESSEL_POPULATION,
|
||||
MOMENTS_BESSEL_SAMPLE,
|
||||
} zdnn_moments_bessel;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Structs
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// describes general pre-transformed or transformed information (e.g. shape) of
|
||||
// a tensor
|
||||
typedef struct zdnn_tensor_desc {
|
||||
zdnn_data_layouts layout; // data layout
|
||||
zdnn_data_formats format; // internal use only
|
||||
zdnn_data_types type; // data type
|
||||
uint32_t dim4; // number of elements in outermost dimension
|
||||
uint32_t dim3; // ... outer dimension
|
||||
uint32_t dim2; // ... inner dimension
|
||||
uint32_t dim1; // number of elements in innermost dimension
|
||||
} zdnn_tensor_desc;
|
||||
|
||||
// struct for describing a ztensor
|
||||
typedef struct zdnn_ztensor {
|
||||
zdnn_tensor_desc
|
||||
*pre_transformed_desc; // tensor's shape information before transformation
|
||||
zdnn_tensor_desc *transformed_desc; // transformed tensor's shape information
|
||||
uint64_t buffer_size; // tensor size in bytes
|
||||
void *buffer; // pointer to the tensor in memory
|
||||
bool is_transformed; // indicator if data in buffer has been transformed
|
||||
char reserved[3]; // not currently used, should contain zeros.
|
||||
float rec_scale; // the scale factor for quantization, stored as reciprocal
|
||||
float offset; // the offset for quantization
|
||||
char reserved2[20]; // not currently used, should contain zeros.
|
||||
} zdnn_ztensor;
|
||||
|
||||
#define ZDNN_VERSION "1.2.0"
|
||||
#define ZDNN_VERNUM 0x010200 // 0x[major][minor][patch]
|
||||
#define ZDNN_VER_MAJOR 1
|
||||
#define ZDNN_VER_MINOR 2
|
||||
#define ZDNN_VER_PATCH 0
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// External Tensor Functions
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// Concatenation information is encoded into a 32-bit word:
|
||||
// [RNN_TYPE: 8][PREV_LAYER_TYPE: 8][USAGE: 8][8]
|
||||
|
||||
typedef uint32_t zdnn_concat_info;
|
||||
|
||||
#define BITSHIFT_RNN_TYPE 24
|
||||
#define BITSHIFT_PREV_LAYER 16
|
||||
#define BITSHIFT_USAGE 8
|
||||
|
||||
#define RNN_TYPE_LSTM (0 << BITSHIFT_RNN_TYPE)
|
||||
#define RNN_TYPE_GRU (1 << BITSHIFT_RNN_TYPE)
|
||||
|
||||
#define PREV_LAYER_UNI (0 << BITSHIFT_PREV_LAYER)
|
||||
#define PREV_LAYER_NONE PREV_LAYER_UNI
|
||||
#define PREV_LAYER_BIDIR (1 << BITSHIFT_PREV_LAYER)
|
||||
|
||||
#define USAGE_WEIGHTS (0 << BITSHIFT_USAGE)
|
||||
#define USAGE_HIDDEN_WEIGHTS (1 << BITSHIFT_USAGE)
|
||||
#define USAGE_BIASES (2 << BITSHIFT_USAGE)
|
||||
#define USAGE_HIDDEN_BIASES (3 << BITSHIFT_USAGE)
|
||||
|
||||
#define CONCAT_RNN_TYPE(info) (info & (0xFFu << BITSHIFT_RNN_TYPE))
|
||||
#define CONCAT_PREV_LAYER(info) (info & (0xFFu << BITSHIFT_PREV_LAYER))
|
||||
#define CONCAT_USAGE(info) (info & (0xFFu << BITSHIFT_USAGE))
|
||||
|
||||
void zdnn_init_pre_transformed_desc(zdnn_data_layouts layout,
|
||||
zdnn_data_types type,
|
||||
zdnn_tensor_desc *pre_tfrmd_desc, ...);
|
||||
|
||||
zdnn_status
|
||||
zdnn_generate_transformed_desc(const zdnn_tensor_desc *pre_tfrmd_desc,
|
||||
zdnn_tensor_desc *tfrmd_desc);
|
||||
|
||||
zdnn_status zdnn_generate_quantized_transformed_desc(
|
||||
const zdnn_tensor_desc *pre_tfrmd_desc,
|
||||
zdnn_quantized_transform_types transform_type,
|
||||
zdnn_tensor_desc *tfrmd_desc);
|
||||
|
||||
zdnn_status zdnn_generate_transformed_desc_concatenated(
|
||||
const zdnn_tensor_desc *pre_tfrmd_desc, zdnn_concat_info info,
|
||||
zdnn_tensor_desc *tfrmd_desc);
|
||||
|
||||
zdnn_status zdnn_allochelper_ztensor(zdnn_ztensor *ztensor);
|
||||
zdnn_status zdnn_free_ztensor_buffer(const zdnn_ztensor *ztensor);
|
||||
|
||||
void zdnn_init_ztensor(zdnn_tensor_desc *pre_tfrmd_desc,
|
||||
zdnn_tensor_desc *tfrmd_desc, zdnn_ztensor *output);
|
||||
|
||||
void zdnn_init_quantized_ztensor(zdnn_tensor_desc *pre_tfrmd_desc,
|
||||
zdnn_tensor_desc *tfrmd_desc, float scale,
|
||||
float offset, zdnn_ztensor *output);
|
||||
|
||||
zdnn_status zdnn_init_ztensor_with_malloc(zdnn_tensor_desc *pre_tfrmd_desc,
|
||||
zdnn_tensor_desc *tfrmd_desc,
|
||||
zdnn_ztensor *output);
|
||||
|
||||
zdnn_status zdnn_init_quantized_ztensor_with_malloc(
|
||||
zdnn_tensor_desc *pre_tfrmd_desc, zdnn_tensor_desc *tfrmd_desc, float scale,
|
||||
float offset, zdnn_ztensor *output);
|
||||
|
||||
bool zdnn_is_quantized_ztensor(zdnn_ztensor *ztensor);
|
||||
|
||||
void zdnn_reset_ztensor(zdnn_ztensor *ztensor);
|
||||
|
||||
uint64_t zdnn_getsize_ztensor(const zdnn_tensor_desc *tfrmd_desc);
|
||||
|
||||
zdnn_status zdnn_getrange_ztensor(const zdnn_ztensor *ztensor, float *min,
|
||||
float *max);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// External Query Functions
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
bool zdnn_is_nnpa_installed();
|
||||
bool zdnn_is_nnpa_function_installed(int count, ...);
|
||||
bool zdnn_is_nnpa_parmblk_fmt_installed(int count, ...);
|
||||
bool zdnn_is_nnpa_datatype_installed(uint16_t types_bitmask);
|
||||
bool zdnn_is_nnpa_layout_fmt_installed(uint32_t layout_bitmask);
|
||||
bool zdnn_is_nnpa_conversion_installed(nnpa_data_type type,
|
||||
uint16_t format_bitmask);
|
||||
|
||||
uint32_t zdnn_get_nnpa_max_dim_idx_size();
|
||||
uint32_t zdnn_get_max_for_dim(uint8_t dimension);
|
||||
uint64_t zdnn_get_nnpa_max_tensor_size();
|
||||
|
||||
zdnn_status zdnn_refresh_nnpa_query_result();
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Versioning Functions
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
bool zdnn_is_version_runnable(uint32_t ver_num);
|
||||
uint32_t zdnn_get_max_runnable_version();
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// External Elementwise Operations
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
zdnn_status zdnn_add(const zdnn_ztensor *input_a, const zdnn_ztensor *input_b,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_sub(const zdnn_ztensor *input_a, const zdnn_ztensor *input_b,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_mul(const zdnn_ztensor *input_a, const zdnn_ztensor *input_b,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_div(const zdnn_ztensor *input_a, const zdnn_ztensor *input_b,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_min(const zdnn_ztensor *input_a, const zdnn_ztensor *input_b,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_max(const zdnn_ztensor *input_a, const zdnn_ztensor *input_b,
|
||||
zdnn_ztensor *output);
|
||||
|
||||
zdnn_status zdnn_log(const zdnn_ztensor *input, zdnn_ztensor *output);
|
||||
zdnn_status zdnn_exp(const zdnn_ztensor *input, zdnn_ztensor *output);
|
||||
zdnn_status zdnn_sqrt(const zdnn_ztensor *input, zdnn_ztensor *output);
|
||||
zdnn_status zdnn_invsqrt(const zdnn_ztensor *input, float epsilon,
|
||||
zdnn_ztensor *output);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// External Activation Operations
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
zdnn_status zdnn_relu(const zdnn_ztensor *input, const void *clipping_value,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_leaky_relu(const zdnn_ztensor *input,
|
||||
const void *clipping_value, float adjustment_factor,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_tanh(const zdnn_ztensor *input, zdnn_ztensor *output);
|
||||
zdnn_status zdnn_sigmoid(const zdnn_ztensor *input, zdnn_ztensor *output);
|
||||
zdnn_status zdnn_softmax(const zdnn_ztensor *input, void *save_area,
|
||||
zdnn_softmax_act act_func, zdnn_ztensor *output);
|
||||
zdnn_status zdnn_softmax_mask(const zdnn_ztensor *input, void *save_area,
|
||||
zdnn_softmax_act act_func, uint32_t softmax_mask,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_gelu(const zdnn_ztensor *input, zdnn_ztensor *output);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Recurrent Neural Network (RNN) Operations
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
typedef enum lstm_gru_direction { FWD, BWD, BIDIR } lstm_gru_direction;
|
||||
|
||||
zdnn_status zdnn_lstm(const zdnn_ztensor *input, const zdnn_ztensor *h0,
|
||||
const zdnn_ztensor *c0, const zdnn_ztensor *weights,
|
||||
const zdnn_ztensor *biases,
|
||||
const zdnn_ztensor *hidden_weights,
|
||||
const zdnn_ztensor *hidden_biases,
|
||||
lstm_gru_direction direction, void *work_area,
|
||||
zdnn_ztensor *hn_output, zdnn_ztensor *cf_output);
|
||||
zdnn_status zdnn_gru(const zdnn_ztensor *input, const zdnn_ztensor *h0,
|
||||
const zdnn_ztensor *weights, const zdnn_ztensor *biases,
|
||||
const zdnn_ztensor *hidden_weights,
|
||||
const zdnn_ztensor *hidden_biases,
|
||||
lstm_gru_direction direction, void *work_area,
|
||||
zdnn_ztensor *hn_output);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Matrix Multiplication Operations
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
zdnn_status zdnn_matmul_op(const zdnn_ztensor *input_a,
|
||||
const zdnn_ztensor *input_b,
|
||||
const zdnn_ztensor *input_c, zdnn_matmul_ops op_type,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_matmul_bcast_op(const zdnn_ztensor *input_a,
|
||||
const zdnn_ztensor *input_b,
|
||||
const zdnn_ztensor *input_c,
|
||||
zdnn_matmul_bcast_ops op_type,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_matmul_transpose_op(const zdnn_ztensor *input_a,
|
||||
const zdnn_ztensor *input_b,
|
||||
const zdnn_ztensor *input_c,
|
||||
bool transpose_a, bool transpose_b,
|
||||
zdnn_matmul_ops op_type,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_quantized_matmul_op(
|
||||
const zdnn_ztensor *input_a, const zdnn_ztensor *input_b,
|
||||
const zdnn_ztensor *input_c, zdnn_matmul_ops op_type, const int8_t clip_min,
|
||||
const int8_t clip_max, const bool disable_clipping, const bool dequantize,
|
||||
const bool pre_computed, void *work_area, zdnn_ztensor *output);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// External Norm Operations
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
zdnn_status zdnn_batchnorm(const zdnn_ztensor *input_a,
|
||||
const zdnn_ztensor *input_b,
|
||||
const zdnn_ztensor *input_c, zdnn_ztensor *output);
|
||||
zdnn_status zdnn_norm(const zdnn_ztensor *input_a, const zdnn_ztensor *input_b,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_moments(const zdnn_ztensor *input,
|
||||
zdnn_moments_bessel bessel_correction_type,
|
||||
zdnn_ztensor *output_a, zdnn_ztensor *output_b);
|
||||
zdnn_status zdnn_layernorm(const zdnn_ztensor *input_a,
|
||||
const zdnn_ztensor *input_b,
|
||||
const zdnn_ztensor *input_c, const float beta_value,
|
||||
const float gamma_value, const float epsilon_value,
|
||||
zdnn_ztensor *output);
|
||||
zdnn_status zdnn_meanreduce2d(const zdnn_ztensor *input, zdnn_ztensor *output);
|
||||
|
||||
zdnn_status zdnn_reduce(const zdnn_ztensor *input, void *save_area,
|
||||
zdnn_reduce_ops op_type, zdnn_ztensor *output);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// External Pool Operations
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
zdnn_status zdnn_avgpool2d(const zdnn_ztensor *input,
|
||||
zdnn_pool_padding padding_type,
|
||||
uint32_t kernel_height, uint32_t kernel_width,
|
||||
uint32_t stride_height, uint32_t stride_width,
|
||||
zdnn_ztensor *output);
|
||||
|
||||
zdnn_status zdnn_maxpool2d(const zdnn_ztensor *input,
|
||||
zdnn_pool_padding padding_type,
|
||||
uint32_t kernel_height, uint32_t kernel_width,
|
||||
uint32_t stride_height, uint32_t stride_width,
|
||||
zdnn_ztensor *output);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// External Convolution Operations
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
zdnn_status zdnn_conv2d(const zdnn_ztensor *input, const zdnn_ztensor *kernel,
|
||||
const zdnn_ztensor *bias,
|
||||
zdnn_pool_padding padding_type, uint32_t stride_height,
|
||||
uint32_t stride_width, zdnn_conv2d_act act_func,
|
||||
const void *clipping_value, zdnn_ztensor *output);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// External Tensor Transform Operations
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
zdnn_status zdnn_transform_ztensor(zdnn_ztensor *ztensor, ...);
|
||||
|
||||
zdnn_status zdnn_transform_ztensor_with_saturation(zdnn_ztensor *ztensor, ...);
|
||||
|
||||
zdnn_status zdnn_transform_quantized_ztensor(zdnn_ztensor *ztensor,
|
||||
bool saturation_control,
|
||||
int8_t clip_min, int8_t clip_max,
|
||||
const void *data);
|
||||
|
||||
zdnn_status zdnn_transform_origtensor(const zdnn_ztensor *ztensor,
|
||||
void *out_buf);
|
||||
|
||||
zdnn_status zdnn_reshape_ztensor(const zdnn_ztensor *src, zdnn_ztensor *dest);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// External Version Related Functions
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
char *zdnn_get_library_version_str();
|
||||
uint32_t zdnn_get_library_version();
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// zDNN Status Related Functions
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
const char *zdnn_get_status_message(zdnn_status status);
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// zDNN Data Type Limit Functions
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
zdnn_status zdnn_get_max_limit(zdnn_data_types transformed_type,
|
||||
zdnn_data_types pre_transformed_type,
|
||||
void *limit);
|
||||
zdnn_status zdnn_get_min_limit(zdnn_data_types transformed_type,
|
||||
zdnn_data_types pre_transformed_type,
|
||||
void *limit);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif /* __cplusplus */
|
||||
|
||||
#endif /* ZDNN_ZDNN_H_ */
|
||||
Reference in New Issue
Block a user