// NOTE: this span is a unified-diff hunk for record_matmul_texture3d, not
// plain source. Lines whose number prefix ends in '-' are removed, lines whose
// prefix ends in '+' are added, and lines carrying two fused numbers are
// unchanged context shared by the old and new versions.
//
// Summary of the change shown below:
//   * adds a `mat2_is_transposed` flag that selects the base kernel name;
//   * replaces the four UBO arguments (out sizes, out logical limits,
//     mat1 sizes, mat2 sizes) with a single PushConstants struct;
//   * replaces the single submit_compute_job call with
//     get_descriptor_set + three bind calls + register_shader_dispatch.
//
// NOTE(review): the `context` parameter is not referenced in the visible
// lines; both the removed and the added code call api::context() instead.
// Confirm whether that is intentional.
@@ -357,33 +357,76 @@ void record_matmul_texture3d(
357357 api::Context* context,
358358 api::vTensor& out,
359359 api::vTensor& mat1,
360- api::vTensor& mat2) {
361- std::string kernel_name = " matmul_naive" ;
360+ api::vTensor& mat2,
361+ bool mat2_is_transposed) {
// The base kernel name now depends on mat2_is_transposed; the storage-type
// and dtype suffixes are appended afterwards, as before.
362+ std::string kernel_name =
363+ mat2_is_transposed ? " matmul_transposed_naive" : " matmul_naive" ;
362364 kernel_name.reserve (kShaderNameReserve );
363365 add_storage_type_suffix (kernel_name, out.storage_type ());
364366 add_dtype_suffix (kernel_name, out.dtype ());
365367
// The global work group size is taken from the output tensor's logical limits.
366368 utils::uvec3 global_wg_size = out.logical_limits ();
367369
// Data handed to the dispatch in place of the removed sizes/limits UBOs.
// NOTE(review): presumably the member order and packing must match the
// push-constant block declared in the shader -- confirm against the shader.
370+ struct PushConstants {
371+ utils::ivec4 out_sizes;
372+ utils::ivec4 mat1_sizes;
373+ utils::ivec4 mat2_sizes;
374+ utils::ivec3 out_limits;
375+ };
376+
// Copies at most the first four entries of `sizes` into an ivec4, narrowing
// each int64_t to int32_t. Components with no corresponding entry stay 1.
// NOTE(review): entries are copied in the order the vector stores them;
// confirm this is the dimension order the shader expects.
377+ auto make_ivec4 = [](const std::vector<int64_t >& sizes) -> utils::ivec4 {
378+ utils::ivec4 result{1 , 1 , 1 , 1 };
379+ for (size_t i = 0 ; i < std::min (sizes.size (), size_t (4 )); ++i) {
380+ result.data [i] = static_cast <int32_t >(sizes[i]);
381+ }
382+ return result;
383+ };
384+
// Converts the three components of a uvec3 to int32_t.
385+ auto make_ivec3 = [](const utils::uvec3& v) -> utils::ivec3 {
386+ return {
387+ static_cast <int32_t >(v.data [0 ]),
388+ static_cast <int32_t >(v.data [1 ]),
389+ static_cast <int32_t >(v.data [2 ])};
390+ };
391+
// Initializers are listed in the same order as the struct members above.
392+ PushConstants push_constants{
393+ make_ivec4 (out.sizes ()),
394+ make_ivec4 (mat1.sizes ()),
395+ make_ivec4 (mat2.sizes ()),
396+ make_ivec3 (out.logical_limits ()),
397+ };
398+
368399 vkapi::PipelineBarrier pipeline_barrier{};
// Removed: the one-shot submit_compute_job call. Its remaining arguments are
// the '-' lines further down.
369- api::context ()->submit_compute_job (
400+
// Same three hashed layouts that the removed call passed inline.
401+ vkapi::SpecVarList specialization_constants = {
402+ out.hashed_layout (), mat1.hashed_layout (), mat2.hashed_layout ()};
403+
// Same {8, 8, 1} local size that the removed call passed inline.
404+ utils::uvec3 local_wg_size = {8 , 8 , 1 };
405+
// Added: obtain a descriptor set for the kernel, then bind each image to it.
406+ vkapi::DescriptorSet descriptor_set = api::context ()->get_descriptor_set (
370407 VK_KERNEL_FROM_STR (kernel_name),
371- pipeline_barrier ,
372- global_wg_size ,
373- { 8 , 8 , 1 },
374- {out. hashed_layout (), mat1. hashed_layout (), mat2. hashed_layout ()},
375- VK_NULL_HANDLE,
408+ utils::WorkgroupSize (local_wg_size) ,
409+ specialization_constants ,
410+ sizeof (push_constants));
411+
// Binding 0: the output image, requested with WRITE access.
412+ descriptor_set. bind (
376413 0 ,
377414 out.image (
378415 pipeline_barrier,
379416 vkapi::PipelineStage::COMPUTE,
380- vkapi::MemoryAccessType::WRITE),
381- mat1.image (pipeline_barrier, vkapi::PipelineStage::COMPUTE),
382- mat2.image (pipeline_barrier, vkapi::PipelineStage::COMPUTE),
383- out.sizes_ubo (),
384- out.logical_limits_ubo (),
385- mat1.sizes_ubo (),
386- mat2.sizes_ubo ());
417+ vkapi::MemoryAccessType::WRITE));
// Bindings 1 and 2: the mat1 and mat2 images, with no explicit access type.
418+ descriptor_set.bind (
419+ 1 , mat1.image (pipeline_barrier, vkapi::PipelineStage::COMPUTE));
420+ descriptor_set.bind (
421+ 2 , mat2.image (pipeline_barrier, vkapi::PipelineStage::COMPUTE));
422+
// NOTE(review): &push_constants is the address of a local variable. Whether
// register_shader_dispatch copies the data before this function returns
// cannot be determined from this hunk -- confirm against its definition.
423+ api::context ()->register_shader_dispatch (
424+ descriptor_set,
425+ pipeline_barrier,
426+ VK_KERNEL_FROM_STR (kernel_name),
427+ global_wg_size,
428+ &push_constants,
429+ sizeof (push_constants));
387430}
388431
389432//
0 commit comments