vulkan: Reuse conversion results in prealloc_y (#15410)

* vulkan: Reuse conversion results in prealloc_y

Cache the pipeline and tensor that were most recently used to fill prealloc_y,
and skip the conversion if the current pipeline/tensor match.

* don't use shared pointer for prealloc_y_last_pipeline_used
This commit is contained in:
Jeff Bolz
2025-08-21 09:55:00 -05:00
committed by GitHub
parent 9ad5e60dba
commit 96452a3fa4
2 changed files with 94 additions and 23 deletions

View File

@@ -3098,9 +3098,10 @@ struct test_mul_mat : public test_case {
const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
const std::array<int64_t, 4> per; // permutation of dimensions
const bool v; // whether a and b are non-contiguous views
const uint32_t o; // number of outputs
std::string vars() override {
return VARS_TO_STR9(type_a, type_b, m, n, k, bs, nr, per, v);
return VARS_TO_STR10(type_a, type_b, m, n, k, bs, nr, per, v, o);
}
double max_nmse_err() override {
@@ -3121,8 +3122,8 @@ struct test_mul_mat : public test_case {
std::array<int64_t, 2> bs = {10, 10},
std::array<int64_t, 2> nr = {2, 2},
std::array<int64_t, 4> per = {0, 1, 2, 3},
bool v = false)
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v) {}
bool v = false, uint32_t o = 1)
: type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr), per(per), v(v), o(o) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
@@ -3186,9 +3187,21 @@ struct test_mul_mat : public test_case {
ggml_tensor * out = ggml_mul_mat(ctx, a, b);
ggml_set_name(out, "out");
for (uint32_t i = 1; i < o; ++i) {
ggml_tensor * out2 = ggml_mul_mat(ctx, a, b);
ggml_set_name(out2, "out2");
out = ggml_add(ctx, out, out2);
}
return out;
}
bool run_whole_graph() override { return o > 1; }
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return ggml_op_name(GGML_OP_MUL_MAT);
}
};
// GGML_OP_MUL_MAT_ID
@@ -3201,9 +3214,10 @@ struct test_mul_mat_id : public test_case {
const int64_t m;
const int64_t n;
const int64_t k;
const uint32_t o; // number of outputs
std::string vars() override {
return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);
return VARS_TO_STR9(type_a, type_b, n_mats, n_used, b, m, n, k, o);
}
double max_nmse_err() override {
@@ -3217,9 +3231,9 @@ struct test_mul_mat_id : public test_case {
test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int n_mats = 8, int n_used = 2, bool b = false,
int64_t m = 32, int64_t n = 32, int64_t k = 32)
int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1)
: type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
m(m), n(n), k(k) {
m(m), n(n), k(k), o(o) {
GGML_ASSERT(n_used <= n_mats);
}
@@ -3241,6 +3255,13 @@ struct test_mul_mat_id : public test_case {
ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
ggml_set_name(out, "out");
for (uint32_t i = 1; i < o; ++i) {
ggml_tensor * a2 = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
ggml_tensor * out2 = ggml_mul_mat_id(ctx, a2, b, ids);
ggml_set_name(out2, "out2");
out = ggml_add(ctx, out, out2);
}
return out;
}
@@ -3264,6 +3285,13 @@ struct test_mul_mat_id : public test_case {
}
}
}
bool run_whole_graph() override { return o > 1; }
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return ggml_op_name(GGML_OP_MUL_MAT_ID);
}
};
// GGML_OP_OUT_PROD
@@ -5798,6 +5826,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 193, {1, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67, {1, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1, 1}, {1, 1}, {0, 1, 2, 3}, true, 3));
for (auto bs2 : {1,3}) {
for (auto bs : {1,2,4,8}) {
@@ -5826,6 +5855,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {