我想使用 SASS 指令,但 (AFAICT) 无法通过 PTX 指令使用。即,假设它是:
HMMA.16816.F16
- 半精度数据的扭曲宽度矩阵乘法加法,形状 M=16、N=8、K=16 (IIANM)。
CUDA 12.4 的CUDA PTX ISA 指南指出第 9.7.13.3 节中,在 FP16 精度下,我们只有形状 (M,N,K) 为 (16, 16, 16) 或(32, 8, 16) 或 (8, 32, 16) - 不小。但是第 9.7.13.1 节说较小的矩阵形状 - (16, 8, 16)、(16, 8, 8) 和 (8, 8, 4)。
尝试使用与这些较小形状相对应的内在函数,例如:
__hmma_m16n8k16_ld_a
导致错误:
mma-smaller.hpp(86): error: identifier "__hmma_m16n8k16_ld_a" is undefined
__hmma_m16n8k16_ld_a((int*)&a, (const int*)p, ldm, 0);
^
那么 PTX 是否支持这些形状?
TL;DR:您可以通过适当选择 PTX 级
mma
指令(不是 wmma
)来发出这样的 SASS 指令,但据我所知,目前还没有相应的 C++ 内在函数可以执行此操作。
更长: 让我们从一些一般背景开始来理清其中一些想法。 mma 类指令主要用于练习张量核心单元,这些单元提供硬件加速的矩阵-矩阵乘法运算。
wmma.mma
、mma
和 wgmma.mma_async
。wmma.mma
指令的特点是它们还具有相应的矩阵加载和存储指令 - 它们不直接公开每线程寄存器存储占用空间。另一方面,mma
指令直接采用 PTX 寄存器输入/输出。wmma
风格的操作被公开——这是可能的张量核心操作的子集,并且该子集对应于PTXwmma.mma
指令,并且该子集也通过以下事实进行区分使用矩阵加载/存储函数,而不是直接寄存器操作。wmma:mma_sync(...)
。 C++ 编程指南中没有记录类似 __hmma_m16n8k16_ld_a
PTX (8.3) 不涵盖较小形状的 WMMA 指令吗?
是的,您可以使用 PTX 发出 16x8x16 (M,N,K) 16 位浮点张量核运算。它不能直接使用 C++ 内在函数来完成,并且在 PTX 中我不会使用
wmma.mma
指令,我会使用 这个 mma PTX 指令 - mma.m16n8k16 。有关 PTX 寄存器布局的详细说明请参见here。 here 给出了指令框架示例。该链接后面的“目标 ISA 注释”部分提供了硬件支持信息。值得注意的是:
.f16 浮点型 mma 运算,.m16n8k16 形状需要 sm_80 或更高。
这是一个完整的示例(对我所描述的内容进行了修改此处):
# cat t153.cu
#include <mma.h>
#include <cuda_fp16.h>
#include <iostream>
#include <stdio.h>
__global__ void mma_fp16_acc_fp32(float *out) {
float c[4] = {0., 0., 0., 0.};
float d[4] = {0., 0., 0., 0.};
half a[8] = {1., 1., 1., 1., 1., 1., 1., 1.};
half b[4] = {1., 1., 1., 1.};
// the above would set our input matrices to all 1
// now lets modify some values
if (threadIdx.x%4 == 0) {
// set the first column of A to be 0, 1, 2, 3, ... 15
a[0] = threadIdx.x/4; a[2] = threadIdx.x/4 + 8;
// set the second row of B to 3,3,3, ... 3
b[1] = 3;}
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
float const *C = reinterpret_cast<float const *>(&c);
float *D = reinterpret_cast<float *>(&d);
asm(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
:
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
"r"(B[0]), "r"(B[1]),
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])
);
memcpy(out+threadIdx.x*2, D, 8);
memcpy(out+8*8+threadIdx.x*2, D+2, 8);
}
int main() {
float* h_C = (float*)malloc(16*8*sizeof(float));
float* d_C;
cudaMalloc(&d_C, 16*8*sizeof(float));
mma_fp16_acc_fp32<<<1, 32>>>(d_C);
cudaDeviceSynchronize();
cudaMemcpy(h_C, d_C, 16*8*sizeof(float), cudaMemcpyDeviceToHost);
for (int i = 0; i < 16; i++){
for (int j = 0; j < 8; j++) std::cout << h_C[i*8+j] << " ";
std::cout << std::endl;}
}
# nvcc -o t153 t153.cu -arch=sm_89
# compute-sanitizer ./t153
========= COMPUTE-SANITIZER
17 17 17 17 17 17 17 17
18 18 18 18 18 18 18 18
19 19 19 19 19 19 19 19
20 20 20 20 20 20 20 20
21 21 21 21 21 21 21 21
22 22 22 22 22 22 22 22
23 23 23 23 23 23 23 23
24 24 24 24 24 24 24 24
25 25 25 25 25 25 25 25
26 26 26 26 26 26 26 26
27 27 27 27 27 27 27 27
28 28 28 28 28 28 28 28
29 29 29 29 29 29 29 29
30 30 30 30 30 30 30 30
31 31 31 31 31 31 31 31
32 32 32 32 32 32 32 32
========= ERROR SUMMARY: 0 errors
#
如我提供的链接所示,这些张量核心运算计算:
D = A*B+C
在上面的示例中,我选择使用/声明 A 和 B 为 16 位浮点,而 C 和 D 为 32 位浮点。
如果我们反汇编上面构建的代码,我们会观察到以下内容,表明正在使用 SASS 级别的张量核心操作:
# cuobjdump -sass ./t153
Fatbin elf code:
================
arch = sm_89
code version = [1,7]
host = linux
compile_size = 64bit
code for sm_89
Fatbin elf code:
================
arch = sm_89
code version = [1,7]
host = linux
compile_size = 64bit
code for sm_89
Function : _Z17mma_fp16_acc_fp32Pf
.headerflags @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM89 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM89)"
/*0000*/ IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ; /* 0x00000a00ff017624 */
/* 0x000fc400078e00ff */
/*0010*/ S2R R9, SR_TID.X ; /* 0x0000000000097919 */
/* 0x000e220000002100 */
/*0020*/ HADD2 R3, -RZ.H0_H0, 1, 1 ; /* 0x3c003c00ff037430 */
/* 0x000fe20000000900 */
/*0030*/ IMAD.MOV.U32 R5, RZ, RZ, 0x3c00 ; /* 0x00003c00ff057424 */
/* 0x000fe200078e00ff */
/*0040*/ ULDC.64 UR4, c[0x0][0x118] ; /* 0x0000460000047ab9 */
/* 0x000fe20000000a00 */
/*0050*/ IMAD.MOV.U32 R14, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff0e7424 */
/* 0x000fe400078e00ff */
/*0060*/ PRMT R0, R3.reuse, 0x7610, R0 ; /* 0x0000761003007816 */
/* 0x040fe20000000000 */
/*0070*/ IMAD.MOV.U32 R15, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff0f7424 */
/* 0x000fe200078e00ff */
/*0080*/ PRMT R2, R3, 0x7610, R2 ; /* 0x0000761003027816 */
/* 0x000fe40000000002 */
/*0090*/ LOP3.LUT P0, RZ, R9, 0x3, RZ, 0xc0, !PT ; /* 0x0000000309ff7812 */
/* 0x001fda000780c0ff */
/*00a0*/ @!P0 SHF.R.U32.HI R4, RZ, 0x2, R9 ; /* 0x00000002ff048819 */
/* 0x000fe20000011609 */
/*00b0*/ @!P0 I2F.F16 R3, 0x3 ; /* 0x0000000300038906 */
/* 0x000fe60000200c00 */
/*00c0*/ @!P0 IADD3 R8, R4, 0x8, RZ ; /* 0x0000000804088810 */
/* 0x000fca0007ffe0ff */
/*00d0*/ @!P0 I2F.F16.U32 R0, R4 ; /* 0x0000000400008306 */
/* 0x000e300000200800 */
/*00e0*/ @!P0 I2F.F16.U32 R2, R8 ; /* 0x0000000800028306 */
/* 0x000e620000200800 */
/*00f0*/ PRMT R0, R0, 0x5410, R5 ; /* 0x0000541000007816 */
/* 0x001fe20000000005 */
/*0100*/ IMAD.MOV.U32 R5, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff057424 */
/* 0x000fc600078e00ff */
/*0110*/ PRMT R4, R0, 0x5432, R3 ; /* 0x0000543200047816 */
/* 0x000fe20000000003 */
/*0120*/ IMAD.MOV.U32 R12, RZ, RZ, R0.reuse ; /* 0x000000ffff0c7224 */
/* 0x100fe400078e0000 */
/*0130*/ IMAD.MOV.U32 R3, RZ, RZ, 0x4 ; /* 0x00000004ff037424 */
/* 0x000fe200078e00ff */
/*0140*/ PRMT R13, R2, 0x7610, R0 ; /* 0x00007610020d7816 */
/* 0x002fe20000000000 */
/*0150*/ IMAD.SHL.U32 R2, R9, 0x2, RZ ; /* 0x0000000209027824 */
/* 0x000fc800078e00ff */
/*0160*/ IMAD.WIDE.U32 R2, R2, R3, c[0x0][0x160] ; /* 0x0000580002027625 */
/* 0x000fe400078e0003 */
/*0170*/ HMMA.16816.F32 R4, R12, R4, RZ ; /* 0x000000040c04723c */
/* 0x000f5e00000018ff */
/*0180*/ NOP ; /* 0x0000000000007918 */
/* 0x000fd00000000000 */
/*0190*/ STG.E.U8 [R2.64+0x4], R5 ; /* 0x0000040502007986 */
/* 0x0201e2000c101104 */
/*01a0*/ SHF.R.U32.HI R15, RZ, 0x18, R5.reuse ; /* 0x00000018ff0f7819 */
/* 0x100fe40000011605 */
/*01b0*/ SHF.R.U32.HI R17, RZ, 0x10, R5.reuse ; /* 0x00000010ff117819 */
/* 0x100fe20000011605 */
/*01c0*/ STG.E.U8 [R2.64], R4 ; /* 0x0000000402007986 */
/* 0x000fe2000c101104 */
/*01d0*/ SHF.R.U32.HI R19, RZ, 0x8, R5 ; /* 0x00000008ff137819 */
/* 0x000fe40000011605 */
/*01e0*/ SHF.R.U32.HI R9, RZ, 0x18, R4.reuse ; /* 0x00000018ff097819 */
/* 0x100fe20000011604 */
/*01f0*/ STG.E.U8 [R2.64+0x100], R6 ; /* 0x0001000602007986 */
/* 0x000fe2000c101104 */
/*0200*/ SHF.R.U32.HI R11, RZ, 0x10, R4.reuse ; /* 0x00000010ff0b7819 */
/* 0x100fe40000011604 */
/*0210*/ SHF.R.U32.HI R13, RZ, 0x8, R4 ; /* 0x00000008ff0d7819 */
/* 0x000fe20000011604 */
/*0220*/ STG.E.U8 [R2.64+0x104], R7 ; /* 0x0001040702007986 */
/* 0x000fe2000c101104 */
/*0230*/ SHF.R.U32.HI R21, RZ, 0x18, R6 ; /* 0x00000018ff157819 */
/* 0x000fc40000011606 */
/*0240*/ SHF.R.U32.HI R23, RZ, 0x10, R6.reuse ; /* 0x00000010ff177819 */
/* 0x100fe20000011606 */
/*0250*/ STG.E.U8 [R2.64+0x3], R9 ; /* 0x0000030902007986 */
/* 0x000fe2000c101104 */
/*0260*/ SHF.R.U32.HI R25, RZ, 0x8, R6 ; /* 0x00000008ff197819 */
/* 0x000fe40000011606 */
/*0270*/ SHF.R.U32.HI R27, RZ, 0x18, R7.reuse ; /* 0x00000018ff1b7819 */
/* 0x100fe20000011607 */
/*0280*/ STG.E.U8 [R2.64+0x2], R11 ; /* 0x0000020b02007986 */
/* 0x000fe2000c101104 */
/*0290*/ SHF.R.U32.HI R29, RZ, 0x10, R7.reuse ; /* 0x00000010ff1d7819 */
/* 0x100fe40000011607 */
/*02a0*/ SHF.R.U32.HI R5, RZ, 0x8, R7 ; /* 0x00000008ff057819 */
/* 0x001fe20000011607 */
/*02b0*/ STG.E.U8 [R2.64+0x1], R13 ; /* 0x0000010d02007986 */
/* 0x000fe8000c101104 */
/*02c0*/ STG.E.U8 [R2.64+0x7], R15 ; /* 0x0000070f02007986 */
/* 0x000fe8000c101104 */
/*02d0*/ STG.E.U8 [R2.64+0x6], R17 ; /* 0x0000061102007986 */
/* 0x000fe8000c101104 */
/*02e0*/ STG.E.U8 [R2.64+0x5], R19 ; /* 0x0000051302007986 */
/* 0x000fe8000c101104 */
/*02f0*/ STG.E.U8 [R2.64+0x103], R21 ; /* 0x0001031502007986 */
/* 0x000fe8000c101104 */
/*0300*/ STG.E.U8 [R2.64+0x102], R23 ; /* 0x0001021702007986 */
/* 0x000fe8000c101104 */
/*0310*/ STG.E.U8 [R2.64+0x101], R25 ; /* 0x0001011902007986 */
/* 0x000fe8000c101104 */
/*0320*/ STG.E.U8 [R2.64+0x107], R27 ; /* 0x0001071b02007986 */
/* 0x000fe8000c101104 */
/*0330*/ STG.E.U8 [R2.64+0x106], R29 ; /* 0x0001061d02007986 */
/* 0x000fe8000c101104 */
/*0340*/ STG.E.U8 [R2.64+0x105], R5 ; /* 0x0001050502007986 */
/* 0x000fe2000c101104 */
/*0350*/ EXIT ; /* 0x000000000000794d */
/* 0x000fea0003800000 */
/*0360*/ BRA 0x360; /* 0xfffffff000007947 */
Fatbin ptx code:
================
arch = sm_89
code version = [8,2]
host = linux
compile_size = 64bit
compressed
#
指示的tensorcore sass指令为HMMA.16816.F32 R4、R12、R4、RZ
如果要查看HMMA.16816.F16,则将C和D矩阵切换为16位浮点数,并相应修改PTX指令。像这样的东西:
# cat t154.cu
#include <mma.h>
#include <cuda_fp16.h>
#include <iostream>
#include <stdio.h>
__global__ void mma_fp16_acc_fp32(float *out) {
half c[4] = {0., 0., 0., 0.};
half d[4] = {0., 0., 0., 0.};
half a[8] = {1., 1., 1., 1., 1., 1., 1., 1.};
half b[4] = {1., 1., 1., 1.};
// the above would set our input matrices to all 1
// now lets modify some values
if (threadIdx.x%4 == 0) {
// set the first column of A to be 0, 1, 2, 3, ... 15
a[0] = threadIdx.x/4; a[2] = threadIdx.x/4 + 8;
// set the second row of B to 3,3,3, ... 3
b[1] = 3;}
unsigned const *A = reinterpret_cast<unsigned const *>(&a);
unsigned const *B = reinterpret_cast<unsigned const *>(&b);
unsigned const *C = reinterpret_cast<unsigned const *>(&c);
unsigned *D = reinterpret_cast<unsigned *>(&d);
asm(
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n"
: "=r"(D[0]), "=r"(D[1])
:
"r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]),
"r"(B[0]), "r"(B[1]),
"r"(C[0]), "r"(C[1])
);
memcpy(...);
memcpy(...);
}
int main() {
...
}
# nvcc -o t154 t154.cu -arch=sm_89
# cuobjdump -sass ./t154
Fatbin elf code:
================
arch = sm_89
code version = [1,7]
host = linux
compile_size = 64bit
code for sm_89
Fatbin elf code:
================
arch = sm_89
code version = [1,7]
host = linux
compile_size = 64bit
code for sm_89
Function : _Z17mma_fp16_acc_fp32Pf
.headerflags @"EF_CUDA_TEXMODE_UNIFIED EF_CUDA_64BIT_ADDRESS EF_CUDA_SM89 EF_CUDA_VIRTUAL_SM(EF_CUDA_SM89)"
/*0000*/ IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] ; /* 0x00000a00ff017624 */
/* 0x000fc400078e00ff */
/*0010*/ S2R R8, SR_TID.X ; /* 0x0000000000087919 */
/* 0x000e220000002100 */
/*0020*/ HADD2 R3, -RZ.H0_H0, 1, 1 ; /* 0x3c003c00ff037430 */
/* 0x000fe20000000900 */
/*0030*/ IMAD.MOV.U32 R6, RZ, RZ, 0x3c00 ; /* 0x00003c00ff067424 */
/* 0x000fe200078e00ff */
/*0040*/ ULDC.64 UR4, c[0x0][0x118] ; /* 0x0000460000047ab9 */
/* 0x000fe20000000a00 */
/*0050*/ IMAD.MOV.U32 R7, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff077424 */
/* 0x000fe400078e00ff */
/*0060*/ PRMT R0, R3.reuse, 0x7610, R0 ; /* 0x0000761003007816 */
/* 0x040fe20000000000 */
/*0070*/ IMAD.MOV.U32 R14, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff0e7424 */
/* 0x000fe200078e00ff */
/*0080*/ PRMT R2, R3, 0x7610, R2 ; /* 0x0000761003027816 */
/* 0x000fe20000000002 */
/*0090*/ IMAD.MOV.U32 R15, RZ, RZ, 0x3c003c00 ; /* 0x3c003c00ff0f7424 */
/* 0x000fe200078e00ff */
/*00a0*/ LOP3.LUT P0, RZ, R8, 0x3, RZ, 0xc0, !PT ; /* 0x0000000308ff7812 */
/* 0x001fda000780c0ff */
/*00b0*/ @!P0 SHF.R.U32.HI R4, RZ, 0x2, R8 ; /* 0x00000002ff048819 */
/* 0x000fe20000011608 */
/*00c0*/ @!P0 I2F.F16 R3, 0x3 ; /* 0x0000000300038906 */
/* 0x000fe60000200c00 */
/*00d0*/ @!P0 IADD3 R5, R4, 0x8, RZ ; /* 0x0000000804058810 */
/* 0x000fca0007ffe0ff */
/*00e0*/ @!P0 I2F.F16.U32 R0, R4 ; /* 0x0000000400008306 */
/* 0x000e300000200800 */
/*00f0*/ @!P0 I2F.F16.U32 R2, R5 ; /* 0x0000000500028306 */
/* 0x000e620000200800 */
/*0100*/ PRMT R0, R0, 0x5410, R6 ; /* 0x0000541000007816 */
/* 0x001fc80000000006 */
/*0110*/ PRMT R6, R0, 0x5432, R3 ; /* 0x0000543200067816 */
/* 0x000fe20000000003 */
/*0120*/ IMAD.MOV.U32 R12, RZ, RZ, R0.reuse ; /* 0x000000ffff0c7224 */
/* 0x100fe400078e0000 */
/*0130*/ IMAD.MOV.U32 R3, RZ, RZ, 0x4 ; /* 0x00000004ff037424 */
/* 0x000fe200078e00ff */
/*0140*/ PRMT R13, R2, 0x7610, R0 ; /* 0x00007610020d7816 */
/* 0x002fe20000000000 */
/*0150*/ IMAD.SHL.U32 R2, R8, 0x2, RZ ; /* 0x0000000208027824 */
/* 0x000fc800078e00ff */
/*0160*/ IMAD.WIDE.U32 R2, R2, R3, c[0x0][0x160] ; /* 0x0000580002027625 */
/* 0x000fe400078e0003 */
/*0170*/ HMMA.16816.F16 R6, R12, R6, RZ ; /* 0x000000060c06723c */
/* 0x000f5e00000008ff */
/*0180*/ NOP ; /* 0x0000000000007918 */
/* 0x000fd00000000000 */
/*0190*/ STG.E.U8 [R2.64], R6 ; /* 0x0000000602007986 */
/* 0x020fe2000c101104 */
/*01a0*/ SHF.R.U32.HI R5, RZ, 0x18, R6.reuse ; /* 0x00000018ff057819 */
/* 0x100fe40000011606 */
/*01b0*/ SHF.R.U32.HI R9, RZ, 0x10, R6.reuse ; /* 0x00000010ff097819 */
/* 0x100fe20000011606 */
/*01c0*/ STG.E.U8 [R2.64+0x4], R7 ; /* 0x0000040702007986 */
/* 0x000fe2000c101104 */
/*01d0*/ SHF.R.U32.HI R11, RZ, 0x8, R6 ; /* 0x00000008ff0b7819 */
/* 0x000fe40000011606 */
/*01e0*/ SHF.R.U32.HI R13, RZ, 0x18, R7.reuse ; /* 0x00000018ff0d7819 */
/* 0x100fe20000011607 */
/*01f0*/ STG.E.U8 [R2.64+0x3], R5 ; /* 0x0000030502007986 */
/* 0x000fe2000c101104 */
/*0200*/ SHF.R.U32.HI R15, RZ, 0x10, R7.reuse ; /* 0x00000010ff0f7819 */
/* 0x100fe40000011607 */
/*0210*/ SHF.R.U32.HI R17, RZ, 0x8, R7 ; /* 0x00000008ff117819 */
/* 0x000fe20000011607 */
/*0220*/ STG.E.U8 [R2.64+0x2], R9 ; /* 0x0000020902007986 */
/* 0x000fe8000c101104 */
/*0230*/ STG.E.U8 [R2.64+0x1], R11 ; /* 0x0000010b02007986 */
/* 0x000fe8000c101104 */
/*0240*/ STG.E.U8 [R2.64+0x7], R13 ; /* 0x0000070d02007986 */
/* 0x000fe8000c101104 */
/*0250*/ STG.E.U8 [R2.64+0x6], R15 ; /* 0x0000060f02007986 */
/* 0x000fe8000c101104 */
/*0260*/ STG.E.U8 [R2.64+0x5], R17 ; /* 0x0000051102007986 */
/* 0x000fe2000c101104 */
/*0270*/ EXIT ; /* 0x000000000000794d */
(由于达到答案中的字符限制,我删除了非必要的行)。