Thursday, 1 March 2012

A Caesar Cypher? Using SIMD? Why?

There's no easy answer to that. I had applied to join a group for assembly language programmers on some social networking site, and to check that I hadn't confused them with a group of flat-pack furnishings fanciers, the polite email asked whether I would send them a simple encryption routine written in assembly which didn't rely on xor. Since I seem to be playing with vector instructions, I thought I'd have a go using some SSE (<=4.1) instructions.

I might as well paste what I have and then talk through what it does. So here we go, a Makefile and ASM source file for 64-bit Linux:

---------------------------[ Makefile ]--------------------------

build: asmcrypt.o
	ld --dynamic-linker /lib64/ld-linux-x86-64.so.2 -o asmcrypt asmcrypt.o

asmcrypt.o: asmcrypt.s
	as -gstabs -o asmcrypt.o asmcrypt.s

clean:
	@rm -f *.o asmcrypt

-------------------------------8<-------------------------------
And a reasonably well-commented source file:
--------------------------[ asmcrypt.s ]-------------------------
.section .data
.align 0x20
plaintxt:
        .asciz  "AA:ZZ/aa:zz The quick brown fox, etc.\n"
txtlen= . - plaintxt
rot:
        .byte   13                           # The rotation value
upmin:
        .byte   'A'                          # The uppercase ASCII lower-bound
upmax:
        .byte   'Z'                          # The uppercase ASCII upper-bound
lomin:
        .byte   'a'                          # The lowercase ASCII lower-bound
lomax:
        .byte   'z'                          # The lowercase ASCII upper-bound

.section .bss
.align 0x20
.lcomm  Rvec, 0x10                           # Mem in which to expand 'rot'
.lcomm  Avec, 0x10                           # Mem in which to expand 'upmin'
.lcomm  Zvec, 0x10                           # Mem in which to expand 'upmax'
.lcomm  avec, 0x10                           # Mem in which to expand 'lomin'
.lcomm  zvec, 0x10                           # Mem in which to expand 'lomax'

.section .text

.LrotateRange:
        movdqa  %xmm14, %xmm1                # copy of lower-bound byte-vector 
        pcmpgtb %xmm0, %xmm1                 # create mask where <'A'; chars in src
        movdqa  %xmm0, %xmm2                 # duplicate chars 
        pcmpgtb %xmm13, %xmm2                # create mask where >'Z'; chars in dest
        por     %xmm2, %xmm1                 # derive mask for <'A' || >'Z'

        movdqa  %xmm15, %xmm2                # duplicate rot-vector 
        paddb   %xmm0, %xmm2                 # add chars to rotation 
        movdqa  %xmm1, %xmm3                 # duplicate mask 
        pandn   %xmm2, %xmm3                 # retain upper-case chars -- xmm2 now spare

        movdqa  %xmm3, %xmm4                 # duplicate result
        movdqa  %xmm3, %xmm5                 # and again 
        pcmpgtb %xmm13, %xmm4                # create mask where >'Z' 
        pxor    %xmm6, %xmm6                 # zero register
        pcmpgtb %xmm5, %xmm6                 # create mask where < 0
        por     %xmm6, %xmm4                 # derive mask where >'Z' || < 0
        movdqa  %xmm4, %xmm5                 # duplicate mask (XMM5 is not read again)

        pand    %xmm13, %xmm4                # create subtraction vector
        psubb   %xmm4, %xmm3                 # subtract 'Z' from out-of-bounds caps

        pand    %xmm14, %xmm4                # XMM4 already holds mask&'Z', so this
                                             # leaves 'Z'&'A' = 0x40, i.e. 'A'-1
        paddb   %xmm4, %xmm3                 # add 'A'-1: the net shift is -26

        pand    %xmm1, %xmm0                 # Clear values we're going to 'set'
        paddb   %xmm3, %xmm0                 # Upper case values are done

        retq

.globl _start
_start:
        pushq   %rbp
        mov     %rsp, %rbp                   # Function prologue
        sub     $txtlen+0x10, %rsp           # Result space + slack for last 16-byte store
        and     $-0x10, %rsp                 # Align RSP to a 16-byte boundary

        xor     %rdx, %rdx                   # clear RDX
        mov     $0x05, %rbx                  # move loop-count into RBX
        leaq    Rvec, %rdi                   # copy start-address into RDI
.LexpandAgain:
        movb    rot(%rdx), %al               # copy source-byte into AL
        mov     $0x10, %rcx                  # set RCX (stosb rep-count) to 16
rep     stosb                                # repeat-store AL->&RDI
        inc     %rdx                         # bump loop index
        dec     %rbx                         # decrement loop counter
        jnz     .LexpandAgain                # repeat while RBX != 0

        xor     %rax, %rax                   # Zero RAX, indexing register
        mov     $txtlen, %rdx                # Copy txtlen value into RDX

        movdqa  Rvec, %xmm15                 # Copy rotation byte-vector into XMM15
.LcoreLoop:
        movdqa  plaintxt(%rax), %xmm0        # Copy chars to register
        movdqa  Avec, %xmm14                 # Copy low val byte-vec into XMM14
        movdqa  Zvec, %xmm13                 # Copy high val byte-vec into XMM13
        call    .LrotateRange                # Call routine (no ABI here, thanks)
        movdqa  avec, %xmm14                 # Repeat for lower-case values
        movdqa  zvec, %xmm13
        call    .LrotateRange                # Call routine again
        movdqa  %xmm0, (%rsp, %rax)          # Store rotated result on the stack

        add     $0x10, %rax                  # Add 16 to indexing register
        sub     $0x10, %rdx                  # Sub 16 from textlen
        jg      .LcoreLoop                   # Repeat cycle for remaining chars

        mov     $1, %rax                     # sys_write syscall 
        mov     $1, %rdi                     # stdout
        mov     %rsp, %rsi                   # pointer to result[0]
        mov     $txtlen, %rdx                # number of chars to print
        syscall

        movq    $0x3C, %rax                  # sys_exit syscall
        xorq    %rdi, %rdi
        syscall
-------------------------------8<-------------------------------

Right, so we may as well get started. What does it do? Well, it's not much of an encryption function, that's for sure, since it deliberately only rotates letters and preserves their case. I'm not sure anyone would have any trouble cracking that. Anyway, here's the .data section:
.section .data
.align 0x20
plaintxt:
        .asciz  "AA:ZZ/aa:zz The quick brown fox, etc.\n"
txtlen= . - plaintxt
rot:
        .byte   13         # The rotation value
upmin:
        .byte   'A'        # The uppercase ASCII lower-bound
upmax:
        .byte   'Z'        # The uppercase ASCII upper-bound
lomin:
        .byte   'a'        # The lowercase ASCII lower-bound
lomax:
        .byte   'z'        # The lowercase ASCII upper-bound

This is fairly straightforward. The .align 0x20 directive aligns the start of the plaintxt ASCII string to a 32-byte boundary, which is also a 16-byte boundary, so I can use movdqa on it later in the application. I have defined some single-byte values which I expand into byte-vectors for use with SSE instructions, so that I can do byte-wise addition and subtraction on each byte in an XMM register. The way I expand the byte values into a 16-byte wide vector format is as follows:
        xor     %rdx, %rdx                   # clear RDX
        mov     $0x05, %rbx                  # move loop-count into RBX
        leaq    Rvec, %rdi                   # copy start-address into RDI
.LexpandAgain:
        movb    rot(%rdx), %al               # copy source-byte into AL
        mov     $0x10, %rcx                  # set RCX (stosb rep-count) to 16
rep     stosb                                # repeat-store AL->&RDI
        inc     %rdx                         # bump loop index
        dec     %rbx                         # decrement loop counter
        jnz     .LexpandAgain                # repeat while RBX != 0

This approach takes advantage of the fact that the source bytes are contiguous in memory, as are the destination addresses. It's very simple to reference a new source-byte for every 16 bytes copied by rep stosb.
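For anyone who wants to check the resulting layout without a debugger, here is a rough Python model of that expansion loop. It is only a sketch: the names expand_bytes and LANES are mine and do not appear in the listing, and the bytearray stands in for the five contiguous 16-byte slots that start at Rvec.

```python
LANES = 16  # number of bytes in one XMM register

def expand_bytes(source: bytes) -> bytearray:
    """Repeat each source byte LANES times, like one 'rep stosb' per byte."""
    dest = bytearray()
    for value in source:                  # movb rot(%rdx), %al
        dest += bytes([value]) * LANES    # rep stosb with RCX = 16
    return dest

# rot, upmin, upmax, lomin, lomax in the order they sit in .data
vectors = expand_bytes(bytes([13]) + b"AZaz")
rot_vec, up_min, up_max, lo_min, lo_max = (
    bytes(vectors[i:i + LANES]) for i in range(0, len(vectors), LANES)
)
print(up_max.hex())   # the byte 5a repeated 16 times
```

Because both the source bytes and the destination slots are contiguous, one index walks the inputs while the destination pointer simply keeps advancing.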
        xor     %rax, %rax                   # Zero RAX, indexing register
        mov     $txtlen, %rdx                # Copy txtlen value into RDX
        movdqa  Rvec, %xmm15                 # Copy rotation byte-vector into XMM15
.LcoreLoop:
        movdqa  plaintxt(%rax), %xmm0        # Copy chars to register
        movdqa  Avec, %xmm14                 # Copy low val byte-vec into XMM14
        movdqa  Zvec, %xmm13                 # Copy high val byte-vec into XMM13
        call    .LrotateRange                # Call routine (no ABI here, thanks)

The first three lines simply initialise a few variables we're going to use; the inline comments suffice to describe what's going on. Following the .LcoreLoop label, movdqa moves 128 bits (16 bytes) of data into register XMM0 from the memory address plaintxt plus the value stored in RAX. At the end of the loop (not shown in the snippet) I add 16 to RAX to advance the memory pointer to the next 16 bytes. Since this is an assembler-only function, the call to .LrotateRange doesn't have to abide by the System V ABI, and I set the value of four registers outside the call which affect its execution: XMM0, which contains the text; XMM15, which contains the rotate-by value; and XMM14 and XMM13, which contain the lower and upper bounds for selecting characters to rotate.

The function .LrotateRange uses the upper and lower bounds to define the byte values on which it will operate: in the first call they are the byte values of 'Z' and 'A' respectively. The vector representations of these values in memory are laid out as follows:
Zvec = 'ZZZZZZZZZZZZZZZZ' = 0x5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a
Avec = 'AAAAAAAAAAAAAAAA' = 0x41414141414141414141414141414141
The subsequent call uses the values for 'z' and 'a'.

Let's have a look at the SIMD instructions in .LrotateRange, then:
.LrotateRange:
        movdqa  %xmm14, %xmm1                # copy of lower-bound byte-vector 
        pcmpgtb %xmm0, %xmm1                 # create mask where <'A'; chars in src
        movdqa  %xmm0, %xmm2                 # duplicate chars 
        pcmpgtb %xmm13, %xmm2                # create mask where >'Z'; chars in dest
        por     %xmm2, %xmm1                 # derive mask for <'A' || >'Z'

I can see now why Intel have made so much of the fact that AVX instructions are "non-destructive", in that they don't clobber the value of one of the "source" operands. I make heavy use of pcmpgtb, a byte-wise greater-than test, but it overwrites the contents of the destination operand. If I may borrow from Intel's instruction set reference:
[pcmpgtb] performs a SIMD signed compare for the greater value of the packed byte ... in the destination operand and the source operand. If a data element in the destination operand is greater than the corresponding data element in the source operand, the corresponding data element in the destination operand is set to all 1s; otherwise, it is set to all 0s.
This means that I do plenty of copying of character and mask values between registers just so that the copies can be overwritten by pcmpgtb. Anyway, I digress. So, what's going on?
        movdqa  %xmm14, %xmm1                # copy of lower-bound byte-vector 
movdqa copies 16 bytes of data from XMM14 to XMM1 in preparation for the pcmpgtb instruction that follows.
        pcmpgtb %xmm0, %xmm1                 # create mask where <'A'; chars in src
XMM0 is the source operand; XMM1 is the destination operand (remember this is GNU assembler). XMM0 contains the ASCII bytes from the input string, while XMM1 contains a copy of the lower-bound vector:
XMM0:  . .n.w.o.r.b. .k.c.i.u.q. .e.h.T
XMM0:  206e776f7262206b6369757120656854
XMM1:  41414141414141414141414141414141
I've written it backwards since when you look at register and memory contents in GDB the "small" end (bit 0) is on the right and the high-bits are on the left (it seems to me to be easier not to fight these things). Anyway, the comparison between XMM0 and XMM1 is whether XMM1 is greater. This is important since you may have noticed that there is no pcmpltb instruction - you have to reverse the operands in pcmpgtb to get that functionality (though, do note that reversing the operands would not give you the inverse result where some of the values are equal). The result of the comparison sets the bits in the destination register where the result is 'true':
XMM0:  . .n.w.o.r.b. .k.c.i.u.q. .e.h.T
XMM0:  206e776f7262206b6369757120656854
XMM1:  41414141414141414141414141414141
XMM1:  FF0000000000FF0000000000FF000000
OK - so we've found some spaces...
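The comparison above is easy to reproduce lane by lane. The following is a small Python model of pcmpgtb, assuming nothing beyond the signed-compare behaviour quoted from Intel; the function names are my own, and the reversed print is only there to match the high-byte-first way GDB displays a register.

```python
def to_signed(byte: int) -> int:
    """Reinterpret 0..255 as a signed 8-bit value, as pcmpgtb does."""
    return byte - 256 if byte >= 0x80 else byte

def pcmpgtb(dest: bytes, src: bytes) -> bytes:
    """Per lane: 0xFF where dest > src (signed compare), else 0x00."""
    return bytes(0xFF if to_signed(d) > to_signed(s) else 0x00
                 for d, s in zip(dest, src))

chars = b"The quick brown "      # 16 input bytes (XMM0)
lower = b"A" * 16                # lower-bound vector (XMM1 before the compare)
mask = pcmpgtb(lower, chars)     # 'A' > char, i.e. char < 'A'

# Print the high byte first, the way GDB shows a register.
print(mask[::-1].hex().upper())  # FF0000000000FF0000000000FF000000
```

Note that equal lanes give 0x00 whichever way round the operands go, which is why swapping them yields "less than" rather than "not greater than".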
        movdqa  %xmm0, %xmm2                 # duplicate chars 
        pcmpgtb %xmm13, %xmm2                # create mask where >'Z'; chars in dest
Again, we copy the operand for the destination register, but this time instead of copying the bounding value, we copy the characters, since we want to know which characters are greater than the upper-bound.
XMM2:  . .n.w.o.r.b. .k.c.i.u.q. .e.h.T
XMM2:  206e776f7262206b6369757120656854
XMM13: 5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a
XMM2:  00FFFFFFFFFF00FFFFFFFFFF00FFFF00
There we have it, we've found all the lower-case characters. I should probably mention my usual ASCII table reference at this point!
        por     %xmm2, %xmm1                 # derive mask for <'A' || >'Z'
All that remains to do is combine the two masks in XMM1 and XMM2 and we should have a mask which can tell us which characters fall within our upper and lower bound:
XMM2:  00FFFFFFFFFF00FFFFFFFFFF00FFFF00
XMM1:  FF0000000000FF0000000000FF000000
XMM1:  FFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00
There we go - the right result, albeit as an inverse of what we really want, but no matter, we can deal with that.
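Putting the two comparisons and the por together, the whole "outside the range" mask can be sketched in Python as below. The helper names (outside_range_mask in particular) are hypothetical and exist only for this illustration.

```python
def pcmpgtb(dest: bytes, src: bytes) -> bytes:
    """0xFF in each lane where dest > src as signed bytes, else 0x00."""
    def signed(b: int) -> int:
        return b - 256 if b >= 0x80 else b
    return bytes(0xFF if signed(d) > signed(s) else 0x00
                 for d, s in zip(dest, src))

def por(dest: bytes, src: bytes) -> bytes:
    """Lane-wise OR of two byte vectors."""
    return bytes(d | s for d, s in zip(dest, src))

def outside_range_mask(chars: bytes, low: int, high: int) -> bytes:
    below = pcmpgtb(bytes([low]) * len(chars), chars)    # low > char
    above = pcmpgtb(chars, bytes([high]) * len(chars))   # char > high
    return por(below, above)

chars = b"The quick brown "
mask = outside_range_mask(chars, ord("A"), ord("Z"))
print(mask[::-1].hex().upper())   # FFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00
```

Only the 'T' lane comes out as 0x00, so the mask marks everything the uppercase pass must leave alone.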

The next part is a bit more complicated, so bear with me while I do my best to pick it apart.
        movdqa  %xmm15, %xmm2                # duplicate rot-vector 
        paddb   %xmm0, %xmm2                 # add chars to rotation 
        movdqa  %xmm1, %xmm3                 # duplicate mask 
        pandn   %xmm2, %xmm3                 # retain upper-case chars -- xmm2 now spare
In the section above, I duplicate the contents of the register containing the byte-vector of rotate-by values, then add the input characters to it - all of them. After duplicating the mask which we've so lovingly just created, I use the copy as the destination operand to pandn (note the 'n' at the end: packed-and-not) to clear those values which were flagged as 0xFF:
XMM0:  . .n.w.o.r.b. .k.c.i.u.q. .e.h.T
XMM0:  206e776f7262206b6369757120656854 +
XMM2:  0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d = 
XMM2:  2d7b847c7f6f2d787076827e2d727561 &!
XMM3:  FFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00 =
XMM3:  00000000000000000000000000000061
So there is our rotated upper-case value. However, we need to do some more work since we may have rotated a character out of the end of the range. In fact, that's exactly what has happened to the letter 'T': 0x61 is the ASCII value for the letter 'a'! To fix that we need to select those values greater than 'Z' and wrap them back by 26: subtract the ASCII value for 'Z', then add one less than the value for 'A'. Simple?
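The add-then-mask step traced in the register dump can be checked with a couple of one-line Python stand-ins for paddb and pandn. This is a sketch; the mask is typed in by hand here rather than computed, and the variable names are mine.

```python
def paddb(dest: bytes, src: bytes) -> bytes:
    """Lane-wise add, wrapping modulo 256 like the real instruction."""
    return bytes((d + s) & 0xFF for d, s in zip(dest, src))

def pandn(dest: bytes, src: bytes) -> bytes:
    """Lane-wise (NOT dest) AND src; dest is the operand that gets inverted."""
    return bytes(~d & s & 0xFF for d, s in zip(dest, src))

chars = b"The quick brown "
rot = bytes([13]) * 16
outside = bytes([0x00] + [0xFF] * 15)   # only lane 0 ('T') is inside 'A'..'Z'

summed = paddb(rot, chars)              # every character plus 13
kept = pandn(outside, summed)           # keep only the in-range lanes

print(summed[::-1].hex())   # 2d7b847c7f6f2d787076827e2d727561
print(kept[::-1].hex())     # 00000000000000000000000000000061
```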

There's an additional problem. Let's consider lowercase values. Given that the algorithm should work for any positive rotation less than 26, we need to check that the values haven't overflowed the max value of a signed byte, which is 127. The value of 'z' is 122. There's a likely source of error. This next code block addresses that problem:
        movdqa  %xmm3, %xmm4                 # duplicate result
        movdqa  %xmm3, %xmm5                 # and again 
        pcmpgtb %xmm13, %xmm4                # create mask where >'Z' 
        pxor    %xmm6, %xmm6                 # zero register
        pcmpgtb %xmm5, %xmm6                 # create mask where < 0
        por     %xmm6, %xmm4                 # derive mask where >'Z' || < 0

I create a couple of copies of the rotated upper-case values (for destruction later). Then I compare one of them with the upper-bound vector, which in the current case identifies the rotated 'T' as now being out-of-bounds:
XMM4:  00000000000000000000000000000061 >
XMM13: 5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a5a
XMM4:  000000000000000000000000000000FF
I then zero register XMM6, since I want to do a comparison against zero. Remember we need to check for values which have overflowed a signed byte? Well, if we add 13d to 'z', for example, we get the byte value 0x7a + 0x0d = 0x87, which as an unsigned value is 135d but as a signed value is -121d. pcmpgtb treats its operands as signed values, so we need to catch cases which are greater than the upper bound or less than zero. In my current example text, no values need this treatment since 'T' + 13 only reaches 'a' (0x61), but by way of example let's assume we're doing the second pass with the lowercase ASCII upper and lower bounds:
XMM0:  . .n.w.o.r.b. .k.c.i.u.q. .e.h.T
XMM0:  206e776f7262206b6369757120656854 +
XMM2:  0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d0d = 
XMM2:  2d7b847c7f6f2d787076827e2d727561 &!
XMM3:  FF0000000000FF0000000000FF0000FF =
XMM3:  007b847c7f6f00787076827e00727500
Then we test whether zero is greater than the source operand:
XMM6:  00000000000000000000000000000000 >
XMM5:  007b847c7f6f00787076827e00727500 =
XMM6:  0000FF00000000000000FF0000000000
In this case the 'w' and 'u' characters, which were not detected by the >(byte)'z' test, have been selected. We can now combine those results with por to create a mask and perform some further (simple) arithmetic to subtract the upper bound and add back one less than the lower bound:
        pand    %xmm13, %xmm4                # create subtraction vector
        psubb   %xmm4, %xmm3                 # subtract 'Z' from out-of-bounds caps

        pand    %xmm14, %xmm4                # XMM4 already holds mask&'Z', so this
                                             # leaves 'Z'&'A' = 0x40, i.e. 'A'-1
        paddb   %xmm4, %xmm3                 # add 'A'-1: the net shift is -26

        pand    %xmm1, %xmm0                 # Clear values we're going to 'set'
        paddb   %xmm3, %xmm0                 # Upper case values are done

        retq
I'm going to skip over this quite quickly since it's fairly easy. Using the byte-mask which identifies characters that have exceeded the upper-bound (or are less than zero), I subtract a masked vector of upper-bound values from the character values. The second pand reuses that same register, so each flagged lane ends up holding 'Z' & 'A' = 0x40 (or 'z' & 'a' = 0x60 on the lowercase pass), which is one less than the lower bound; adding it completes a wrap of exactly 26. Once any out-of-bounds values have been fixed, I clear the corresponding bytes in the input text and add the rotated values back into their original positions.
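The wrap-around arithmetic is the easiest part to get wrong by one, so here is a scalar Python sketch of what a single lane goes through. It is not a translation of the listing, just a model of the same steps, and rotate_range is a name I have made up. It relies on the observation that the second pand is applied to a register already masked with the upper bound, so the value added back is high & low (0x40 for 'Z' & 'A', 0x60 for 'z' & 'a').

```python
def signed(byte: int) -> int:
    """Reinterpret 0..255 as a signed 8-bit value."""
    return byte - 256 if byte >= 0x80 else byte

def rotate_range(chars: bytes, rot: int, low: int, high: int) -> bytes:
    """Rotate bytes within low..high by rot; leave all other bytes alone."""
    out = bytearray(chars)
    for i, c in enumerate(chars):
        if not (low <= c <= high):               # lane flagged by the range mask
            continue
        r = (c + rot) & 0xFF                     # paddb wraps modulo 256
        if signed(r) > high or signed(r) < 0:    # the two pcmpgtb tests
            r = (r - high) & 0xFF                # psubb: subtract the upper bound
            r = (r + (high & low)) & 0xFF        # paddb: high & low == low - 1
        out[i] = r
    return bytes(out)

print(rotate_range(b"T", 13, ord("A"), ord("Z")))   # b'G'
print(rotate_range(b"w", 13, ord("a"), ord("z")))   # b'j'
```

The 'w' case is the one that overflows a signed byte (0x77 + 0x0d = 0x84), which is why the less-than-zero test is needed alongside the greater-than-upper-bound test.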

The function .LrotateRange is called a second time for each 16 input bytes, and on that second pass it rotates lowercase characters.
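To see the two passes working together, the sketch below runs an uppercase pass and then a lowercase pass over each 16-byte block and compares the result with Python's built-in rot_13 codec. The function names caesar and rotate_range are hypothetical, the per-lane logic is a scalar model rather than the SSE code, and the trailing NUL that .asciz appends is left out of the test string.

```python
import codecs

def rotate_range(chars: bytes, rot: int, low: int, high: int) -> bytes:
    """Scalar model of one pass: rotate bytes in low..high, wrapping by 26."""
    out = bytearray(chars)
    for i, c in enumerate(chars):
        if low <= c <= high:
            r = c + rot
            if r > high:
                r -= 26
            out[i] = r
    return bytes(out)

def caesar(text: bytes, rot: int = 13) -> bytes:
    out = bytearray()
    for start in range(0, len(text), 16):          # one 16-byte block per trip
        block = text[start:start + 16]
        block = rotate_range(block, rot, ord("A"), ord("Z"))   # first pass
        block = rotate_range(block, rot, ord("a"), ord("z"))   # second pass
        out += block
    return bytes(out)

plain = b"AA:ZZ/aa:zz The quick brown fox, etc.\n"
cipher = caesar(plain)
print(cipher.decode("ascii"), end="")   # NN:MM/nn:mm Gur dhvpx oebja sbk, rgp.
```

The second pass cannot undo the first, because a rotated uppercase letter is still uppercase and so falls outside the lowercase bounds.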

The remaining code in the complete listing is probably nothing you haven't seen before, so I'll skip walking through it line by line. The loop tests whether there's any remaining text to process, and if not the application proceeds to write the "ciphertext" (ha, there's a joke) to stdout using sys_write, before exiting.

For something entirely more professional, do have a look at Agner Fog's optimisation library in which he uses SIMD instructions to do fast string operations. The one which I remember reading most recently was his implementation of strlen. Very clever stuff!
