Friday, 17 February 2012

AVX matrix-multiplication, or something like it

It's been a while since the last post, and I can confidently say that I understand one or two percent of the new (well, new to me) world of AVX instructions. There was the not-so-brief incident involving lots of head-scratching over why my test implementation using vgatherqpd caused a SIGILL exception on my Sandy Bridge laptop (vgatherqpd is an AVX2 instruction). I guess cpuid does have another use outside timing loops ;)

Anyway, I've been wondering how to do some kind of matrix multiplication using AVX instructions (yes, Sandy Bridge supports AVX but not AVX2, nor for that matter FMA). Getting data into the right place with AVX isn't particularly easy: if you were to start with a 2x2 matrix, you might wind up having trouble moving double-precision values between the high and low 128-bit lanes of the YMM registers. Say again?

OK, to take that last point, let's say you have a couple of 2x2 matrices which look a bit like this:
[ 1, 2 ] x [ 5, 6 ]
[ 3, 4 ]   [ 7, 8 ]
then some of the multiplications you need are 1x5, 2x7, 3x6 and 4x8 (the sums 1x5+2x7 and 3x6+4x8 are the two diagonal entries of the product). Consider how the matrices are laid out in memory, though, row by row:
arr0:
 .double 1, 2, 3, 4
arr1:
 .double 5, 6, 7, 8
You can't simply use a vmulpd instruction, since the 6 and 7 need to be swapped for that to work, and they sit on opposite sides of the boundary between the low and high 128-bit lanes. It turns out that swapping those two values over involves a Byzantine set of register manipulations:

1. Copy the high double quadword (DQW, 128 bits) to a spare low DQW (vextractf128).
2. Copy the 6 and 7 into the same low DQW (vblendpd).
3. Swap them around to be 7 and 6 (vpermilpd).
4. Copy them back into the target DQW (vblendpd).
5. Copy the DQW we moved initially back into the original high DQW (vinsertf128/vperm2f128).

vgatherqpd would probably make all of this easier. Anyway, I digress.

It turns out that if you want to write a small test app, you can make your life easy by making the matrices 4x4. A row of four double-precision FP values fills a 256-bit YMM register exactly, so a row's worth of data fits much more naturally into the registers. Here's the sample calculation that the code below performs:
[  1,  3,  5,  7 ] x [  2,  4,  6,  8 ]
[  9, 11, 13, 15 ]   [ 10, 12, 14, 16 ]
[ 17, 19, 21, 23 ]   [ 18, 20, 22, 24 ]
[ 25, 27, 29, 31 ]   [ 26, 28, 30, 32 ]

The vbroadcastsd instruction
The vbroadcastsd instruction copies a single double-precision value from a 64-bit memory location into each of the four slots of a YMM register. Handily, the values you need to multiply by that value are contiguous in memory, since they form a row of the second matrix. Provided they are 32-byte aligned, vmovapd can load them into a second YMM register in one go. vmulpd then multiplies the four pairs of numbers together, as shown below for the first step. I've used the same register notation as the Intel/AMD manuals, with bit 255 on the left and bit 0 on the right:
255[ 1, 1, 1, 1 ]0
        x
255[ 8, 6, 4, 2 ]0
        =
255[ 8, 6, 4, 2 ]0


In the code below, the first pass sits outside the main loop. That lets me skip a completely redundant addition-to-zero step, as well as zeroing those registers in the first place. Within the loop, each subsequent vector product is added to the value already accumulated in the register. Happily, doing the multiplication this way means each accumulator ends up holding one complete row of the result. The registers can therefore be written directly to contiguous memory to form the correct sequence of double-precision values in the array. Register YMM12, which accumulates the first row, contains the following after each relevant set of multiplies and adds (element 0 first):
pre-loop:  1*2,                1*4,                1*6,                1*8
loop 0:    1*2+3*10,           1*4+3*12,           1*6+3*14,           1*8+3*16
loop 1:    1*2+3*10+5*18,      1*4+3*12+5*20,      1*6+3*14+5*22,      1*8+3*16+5*24
loop 2:    1*2+3*10+5*18+7*26, 1*4+3*12+5*20+7*28, 1*6+3*14+5*22+7*30, 1*8+3*16+5*24+7*32
The following code doesn't write anything to stdout, but it's easy enough to step through it with GDB. My "Makefile" this time was a batch file, since VMware Player doesn't yet support AVX instructions. The contents of make.bat are simply: gcc -gstabs -o avxmul.exe -m64 avxmul.s

-------------------------------8<-------------------------------
.section .rodata
        .align 0x20
arr0:
        .double 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31
arr1:
        .double 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32

.section .bss
        .align 0x20
result:
        .space       0x80                                # 16 doubles, 32-byte aligned for vmovapd

.section .text
        .globl main
main:
        push         %rbx                                # rbx, rdi and xmm6-xmm15 are callee-saved
        push         %rdi                                # in the Win64 ABI, so preserve them
        sub          $0x68, %rsp                         # 0x20 shadow space + 0x40 for xmm12-15 + 8 to keep rsp 16-byte aligned
        vmovaps      %xmm12, 0x20(%rsp)
        vmovaps      %xmm13, 0x30(%rsp)
        vmovaps      %xmm14, 0x40(%rsp)
        vmovaps      %xmm15, 0x50(%rsp)
        call         __main

        leaq         arr0(%rip), %rbx                    # RIP-relative addressing: no 32-bit
        leaq         arr1(%rip), %rdi                    # absolute addresses needed
        leaq         result(%rip), %rdx

        vmovapd      (%rdi), %ymm0                       # load 2, 4, 6, 8

        vbroadcastsd 0x00(%rbx), %ymm1                   # load 1, 1, 1, 1
        vmulpd       %ymm1, %ymm0, %ymm12                # mul  2, 4, 6, 8     = 2, 4, 6, 8

        vbroadcastsd 0x20(%rbx), %ymm1                   # load 9, 9, 9, 9
        vmulpd       %ymm1, %ymm0, %ymm13                # mul  2, 4, 6, 8     = 18, 36, 54, 72

        vbroadcastsd 0x40(%rbx), %ymm1                   # load 17, 17, 17, 17
        vmulpd       %ymm1, %ymm0, %ymm14                # mul   2,  4,  6,  8 = 34, 68, 102, 136

        vbroadcastsd 0x60(%rbx), %ymm1                   # load 25, 25, 25, 25
        vmulpd       %ymm1, %ymm0, %ymm15                # mul   2,  4,  6,  8 = 50, 100, 150, 200

        xor          %rax, %rax
        mov          $0x03, %rcx
.Lstart:
        inc          %rax
        add          $0x20, %rdi                         # next row of arr1

        vmovapd      (%rdi), %ymm0                       # On pass 1, load 10, 12, 14, 16

        vbroadcastsd 0x00(%rbx, %rax, 0x08), %ymm1       # load  3,  3,  3,  3
        vmulpd       %ymm1, %ymm0, %ymm2                 # mul  10, 12, 14, 16 = 30, 36, 42, 48
        vaddpd       %ymm2, %ymm12, %ymm12               # add   2,  4,  6,  8 = 32, 40, 48, 56

        vbroadcastsd 0x20(%rbx, %rax, 0x08), %ymm1       # load 11, 11, 11, 11
        vmulpd       %ymm1, %ymm0, %ymm2                 # mul  10, 12, 14, 16 = 110, 132, 154, 176
        vaddpd       %ymm2, %ymm13, %ymm13               # add  18, 36, 54, 72 = 128, 168, 208, 248

        vbroadcastsd 0x40(%rbx, %rax, 0x08), %ymm1       # load 19, 19,  19,  19
        vmulpd       %ymm1, %ymm0, %ymm2                 # mul  10, 12,  14,  16 = 190, 228, 266, 304
        vaddpd       %ymm2, %ymm14, %ymm14               # add  34, 68, 102, 136 = 224, 296, 368, 440

        vbroadcastsd 0x60(%rbx, %rax, 0x08), %ymm1       # load 27,  27,  27,  27
        vmulpd       %ymm1, %ymm0, %ymm2                 # mul  10,  12,  14,  16 = 270, 324, 378, 432
        vaddpd       %ymm2, %ymm15, %ymm15               # add  50, 100, 150, 200 = 320, 424, 528, 632

        dec          %rcx
        jnz          .Lstart

        vmovapd      %ymm12, 0x00(%rdx)                  # Write the result to memory. Check in GDB using
        vmovapd      %ymm13, 0x20(%rdx)                  # x/16fg &result
        vmovapd      %ymm14, 0x40(%rdx)
        vmovapd      %ymm15, 0x60(%rdx)

        vmovaps      0x20(%rsp), %xmm12                  # restore the callee-saved registers
        vmovaps      0x30(%rsp), %xmm13
        vmovaps      0x40(%rsp), %xmm14
        vmovaps      0x50(%rsp), %xmm15
        vzeroupper                                       # clear upper YMM halves before returning to non-AVX code
        add          $0x68, %rsp
        pop          %rdi
        pop          %rbx
        xor          %rax, %rax
        ret

-------------------------------8<-------------------------------
