Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 58 additions & 15 deletions backends/vulkan/test/utils/test_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,33 +357,76 @@ void record_matmul_texture3d(
api::Context* context,
api::vTensor& out,
api::vTensor& mat1,
api::vTensor& mat2) {
std::string kernel_name = "matmul_naive";
api::vTensor& mat2,
bool mat2_is_transposed) {
std::string kernel_name =
mat2_is_transposed ? "matmul_transposed_naive" : "matmul_naive";
kernel_name.reserve(kShaderNameReserve);
add_storage_type_suffix(kernel_name, out.storage_type());
add_dtype_suffix(kernel_name, out.dtype());

utils::uvec3 global_wg_size = out.logical_limits();

struct PushConstants {
utils::ivec4 out_sizes;
utils::ivec4 mat1_sizes;
utils::ivec4 mat2_sizes;
utils::ivec3 out_limits;
};

auto make_ivec4 = [](const std::vector<int64_t>& sizes) -> utils::ivec4 {
utils::ivec4 result{1, 1, 1, 1};
for (size_t i = 0; i < std::min(sizes.size(), size_t(4)); ++i) {
result.data[i] = static_cast<int32_t>(sizes[i]);
}
return result;
};

auto make_ivec3 = [](const utils::uvec3& v) -> utils::ivec3 {
return {
static_cast<int32_t>(v.data[0]),
static_cast<int32_t>(v.data[1]),
static_cast<int32_t>(v.data[2])};
};

PushConstants push_constants{
make_ivec4(out.sizes()),
make_ivec4(mat1.sizes()),
make_ivec4(mat2.sizes()),
make_ivec3(out.logical_limits()),
};

vkapi::PipelineBarrier pipeline_barrier{};
api::context()->submit_compute_job(

vkapi::SpecVarList specialization_constants = {
out.hashed_layout(), mat1.hashed_layout(), mat2.hashed_layout()};

utils::uvec3 local_wg_size = {8, 8, 1};

vkapi::DescriptorSet descriptor_set = api::context()->get_descriptor_set(
VK_KERNEL_FROM_STR(kernel_name),
pipeline_barrier,
global_wg_size,
{8, 8, 1},
{out.hashed_layout(), mat1.hashed_layout(), mat2.hashed_layout()},
VK_NULL_HANDLE,
utils::WorkgroupSize(local_wg_size),
specialization_constants,
sizeof(push_constants));

descriptor_set.bind(
0,
out.image(
pipeline_barrier,
vkapi::PipelineStage::COMPUTE,
vkapi::MemoryAccessType::WRITE),
mat1.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE),
mat2.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE),
out.sizes_ubo(),
out.logical_limits_ubo(),
mat1.sizes_ubo(),
mat2.sizes_ubo());
vkapi::MemoryAccessType::WRITE));
descriptor_set.bind(
1, mat1.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE));
descriptor_set.bind(
2, mat2.image(pipeline_barrier, vkapi::PipelineStage::COMPUTE));

api::context()->register_shader_dispatch(
descriptor_set,
pipeline_barrier,
VK_KERNEL_FROM_STR(kernel_name),
global_wg_size,
&push_constants,
sizeof(push_constants));
}

//
Expand Down
3 changes: 2 additions & 1 deletion backends/vulkan/test/utils/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ void record_matmul_texture3d(
vkcompute::api::Context* context,
vkcompute::api::vTensor& out,
vkcompute::api::vTensor& mat1,
vkcompute::api::vTensor& mat2);
vkcompute::api::vTensor& mat2,
bool mat2_is_transposed = false);

//
// Input & Output Utilities
Expand Down
5 changes: 3 additions & 2 deletions backends/vulkan/test/vulkan_compute_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ TEST_F(VulkanComputeAPITest, tensor_no_copy_transpose_test) {
std::vector<int64_t> mat2_sizes = {N, K};
std::vector<int64_t> out_sizes = {M, N};

for (const auto storage_type : {utils::kTexture3D, utils::kBuffer}) {
for (const auto storage_type : {utils::kBuffer}) {
vTensor mat1 = vTensor(
context(),
mat1_sizes,
Expand Down Expand Up @@ -876,7 +876,8 @@ TEST_F(VulkanComputeAPITest, tensor_no_copy_transpose_test) {
fill_vtensor(mat2, mat2_data);

if (storage_type == utils::kTexture3D) {
record_matmul_texture3d(context(), out, mat1, mat2_t);
record_matmul_texture3d(
context(), out, mat1, mat2_t, /*mat2_is_transposed=*/true);
} else {
record_reference_matmul(context(), out, mat1, mat2_t);
}
Expand Down
Loading