Learning SIMD and accidentally writing a function that's faster than the default `.contains()` method

⚓ Rust    📅 2026-02-06    👤 surdeus    👁️ 9      

surdeus

While learning SIMD, I accidentally found that my SIMD function is faster than the default `.contains()` method.

Here is the benchmark setup:

  • 10 MB vector size
  • 5,000 iterations

The first image shows the case where the target is not in the vector at all, which forces a full scan of every element (result: false).

The second image shows the target placed at the very last index, to verify that the function actually finds it (result: true).

The Playground compiles for the baseline x86-64 target, so the generated code only uses SSE2 instructions with 16-byte `xmm` registers.

In the assembly for the default `.contains()` (a1), it loops byte by byte for small inputs. If the input is 16 bytes or longer (`cmp rcx, 15` / `ja`, i.e. len > 15), it calls `core::slice::memchr::memchr_aligned` instead. (I'm not sure, but this function call may also add some overhead.)

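Initially, I used N=16. The 16-byte version (`a2::<u8, 16>`) checks 16 bytes per loop iteration: `xmm0` holds the target byte copied into all 16 lanes, each 16-byte chunk is loaded into `xmm1`, and a single `pcmpeqb` + `pmovmskb` tells whether any of the 16 bytes matched.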

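Then I tried N=64, and it was even faster. The assembly is effectively a 4× unrolled loop. SSE2 registers are only 16 bytes wide, so each 64-byte chunk is loaded into four registers per iteration (`xmm1` to `xmm4`). Each one is compared against `xmm0`, and the four results are OR-ed together, so there is only one `pmovmskb` and one branch per 64 bytes.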

1000021060

1000021069

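However, the results changed for inputs smaller than 16 bytes (I tested with 15 bytes). The `a2<16>` version was the fastest, with the default `.contains()` method second. Interestingly, both `a2<64>` and `a3<64>` performed the worst in this scenario, even though all of them end up in the same kind of byte-by-byte loop. The only extra step in the SIMD versions is a length check first. They compute `len` rounded down to a multiple of N (the `movabs`/`and` pair), and because that is 0 for 15 bytes, they skip the vector loop and go straight to the scalar tail loop. This is a length check, not a pointer-alignment check: the vector loads are unaligned `movdqu`.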

1000021075

1000021079

However, I don't quite understand the difference between the generic `a2<64>` and `a3<64>`. Can someone explain it?

A1 assembly

# a1 -- presumably a thin wrapper around the standard `slice.contains(&needle)`
# (TODO confirm against the Rust source).
# Registers as used below: rdi = data pointer, rsi = length, dl = needle byte.
# Returns al = 1 if the needle occurs in the slice, al = 0 otherwise.
# Strategy: length 1..15 -> inline byte-by-byte loop; length 0 -> false;
#           length >= 16 -> call core::slice::memchr::memchr_aligned.
# NOTE(review): core's memchr_aligned is believed to be a word-at-a-time (usize)
# scan rather than an SSE2 loop, which would explain why the SSE2 versions win
# on large inputs -- verify against the core::slice::memchr source.
a1:                                     # @a1
# %bb.0:
	# rcx = len; rsi = ptr (rsi is also where memchr_aligned expects the pointer below).
	mov	rcx, rsi
	mov	rsi, rdi
	# Unsigned compare: len > 15 (i.e. len >= 16) takes the memchr_aligned path.
	cmp	rcx, 15
	ja	.LBB12_3
# %bb.1:
	# Empty slice: return false.
	test	rcx, rcx
	je	.LBB12_2
# %bb.4:
	# rcx = index of the last byte; rdi = current index, starting at 0.
	dec	rcx
	xor	edi, edi

.LBB12_5:                               # =>This Inner Loop Header: Depth=1
	# al = (byte == needle); on a match return immediately with al = 1.
	cmp	byte ptr [rsi + rdi], dl
	sete	al
	je	.LBB12_7
# %bb.6:                                #   in Loop: Header=BB12_5 Depth=1
	# Loop until the last index has been checked. lea advances the index without
	# modifying flags, so jne still tests the result of "cmp rcx, rdi".
	cmp	rcx, rdi
	lea	rdi, [rdi + 1]
	jne	.LBB12_5

.LBB12_7:
	# al is 1 after a match, or 0 from the final non-matching sete.
                                        # kill: def $al killed $al killed $eax
	ret

.LBB12_3:
	# Long input: memchr_aligned(needle = edi, ptr = rsi, len = rdx).
	# push rax only adjusts rsp by 8 so the stack is 16-byte aligned at the call.
	push	rax
	movzx	edi, dl
	mov	rdx, rcx
	call	qword ptr [rip + core::slice::memchr::memchr_aligned@GOTPCREL]
	# rax appears to carry the Option discriminant (1 = Some); the found index
	# is ignored and only "found / not found" is returned in al.
	cmp	rax, 1
	sete	al
	add	rsp, 8
                                        # kill: def $al killed $al killed $eax
	ret

.LBB12_2:
	# Empty slice -> false.
	xor	eax, eax
                                        # kill: def $al killed $al killed $eax
	ret
                                        # -- End function

A2 16 assembly

# a2::<u8, 16> -- returns al = 1 if the byte 10 (0x0A) occurs in the slice, else al = 0.
# Registers as used below: rdi = data pointer, rsi = length.
# No needle register is read: the target byte has been folded into the constant
# 10 (see .LCPI0_0 and the tail compare). Presumably the needle is a compile-time
# constant or was constant-propagated from the caller -- TODO confirm in the source.
# Main loop: one 16-byte SSE2 compare per iteration over len rounded down to 16.
# Tail: the remaining len % 16 bytes are scanned one byte at a time.
.LCPI0_0:
	# 16 bytes, each equal to 10: the needle copied into every lane.
	.zero	16,10

playground::a2::<u8, 16>: # @playground::a2::<u8, 16>
# %bb.0:
	# rax = len & 0x7FFF_FFFF_FFFF_FFF0, i.e. len rounded down to a multiple of 16
	# (bit 63 is cleared as well; presumably the compiler knows len <= isize::MAX).
	# This is a length round-down, not a pointer-alignment check.
	movabs	rax, 9223372036854775792
	and	rax, rsi
	# Fewer than 16 bytes: skip the vector loop entirely.
	je	.LBB0_6
# %bb.1:
	# rcx = byte offset of the current 16-byte chunk; xmm0 = needle in all lanes.
	xor	ecx, ecx
	movdqa	xmm0, xmmword ptr [rip + .LCPI0_0] # xmm0 = [10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10]

.LBB0_3:                                # =>This Inner Loop Header: Depth=1
	# Unaligned 16-byte load; pcmpeqb sets a lane to 0xFF where byte == 10;
	# pmovmskb packs the 16 lane sign bits into edx, so nonzero means a match.
	movdqu	xmm1, xmmword ptr [rdi + rcx]
	pcmpeqb	xmm1, xmm0
	pmovmskb	edx, xmm1
	test	edx, edx
	jne	.LBB0_4
# %bb.2:                                #   in Loop: Header=BB0_3 Depth=1
	# Advance one chunk; stop when the offset reaches the rounded-down length.
	add	rcx, 16
	cmp	rax, rcx
	jne	.LBB0_3

.LBB0_6:
	# esi = len % 16 = bytes not covered by the vector loop; none left -> false.
	and	esi, 15
	je	.LBB0_7
# %bb.8:
	# rdi = start of the tail (rax is 0 if the vector loop was skipped);
	# rsi = index of the last tail byte; rcx = current tail index.
	add	rdi, rax
	dec	rsi
	xor	ecx, ecx

.LBB0_9:                                # =>This Inner Loop Header: Depth=1
	# Scalar tail scan against the immediate 10 (same loop shape as a1's byte loop).
	cmp	byte ptr [rdi + rcx], 10
	sete	al
	je	.LBB0_5
# %bb.10:                               #   in Loop: Header=BB0_9 Depth=1
	# lea leaves the flags from "cmp rsi, rcx" intact for jne.
	cmp	rsi, rcx
	lea	rcx, [rcx + 1]
	jne	.LBB0_9

.LBB0_5:
	# al = 1 on a tail match, or 0 from the last non-matching sete.
                                        # kill: def $al killed $al killed $eax
	ret

.LBB0_4:
	# Vector loop found a match -> true.
	mov	al, 1
                                        # kill: def $al killed $al killed $eax
	ret

.LBB0_7:
	# No match in the vector part and no tail bytes -> false.
	xor	eax, eax
                                        # kill: def $al killed $al killed $eax
	ret
                                        # -- End function

A2 64 assembly

# a2::<u8, 64> -- returns al = 1 if the byte 10 (0x0A) occurs in the slice, else al = 0.
# Registers as used below: rdi = data pointer, rsi = length (moved at entry).
# As in a2::<u8, 16>, no needle register is read: the target is the constant 10
# (presumably a compile-time constant or constant-propagated -- TODO confirm).
# SSE2 registers hold 16 bytes, so each 64-byte chunk is handled as four
# unaligned 16-byte loads (xmm1..xmm4). The four compare results are OR-ed into
# one register, giving a single pmovmskb + branch per 64 bytes.
# Tail (len % 64 bytes): 0 -> false, 1..15 -> scalar loop,
# 16..63 -> core::slice::memchr::memchr_aligned on the tail.
.LCPI1_0:
	# 16 bytes, each equal to 10: the needle copied into every lane.
	.zero	16,10

playground::a2::<u8, 64>: # @playground::a2::<u8, 64>
# %bb.0:
	# rdx = len, rsi = ptr: the registers memchr_aligned takes them in on the tail path.
	mov	rdx, rsi
	mov	rsi, rdi
	# rax = len & 0x7FFF_FFFF_FFFF_FFC0, i.e. len rounded down to a multiple of 64
	# (a length round-down, not an alignment check).
	movabs	rax, 9223372036854775744
	and	rax, rdx
	# Fewer than 64 bytes: skip the vector loop.
	je	.LBB1_6
# %bb.1:
	# rcx = byte offset of the current 64-byte chunk; xmm0 = needle in all lanes.
	xor	ecx, ecx
	movdqa	xmm0, xmmword ptr [rip + .LCPI1_0] # xmm0 = [10,10,10,10,10,10,10,10,10,10,10,10,10,10,10,10]

.LBB1_3:                                # =>This Inner Loop Header: Depth=1
	# Load the four 16-byte quarters of the chunk.
	movdqu	xmm1, xmmword ptr [rsi + rcx]
	movdqu	xmm2, xmmword ptr [rsi + rcx + 16]
	movdqu	xmm3, xmmword ptr [rsi + rcx + 32]
	movdqu	xmm4, xmmword ptr [rsi + rcx + 48]
	# Compare each quarter with the needle and OR the four masks together:
	# a lane of xmm2 ends up 0xFF if that lane matched in any quarter.
	pcmpeqb	xmm3, xmm0
	pcmpeqb	xmm1, xmm0
	por	xmm1, xmm3
	pcmpeqb	xmm4, xmm0
	pcmpeqb	xmm2, xmm0
	por	xmm2, xmm4
	por	xmm2, xmm1
	# One mask extraction and one branch per 64 bytes.
	pmovmskb	edi, xmm2
	test	edi, edi
	jne	.LBB1_4
# %bb.2:                                #   in Loop: Header=BB1_3 Depth=1
	add	rcx, 64
	cmp	rax, rcx
	jne	.LBB1_3

.LBB1_6:
	# rsi = start of the tail (ptr + bytes already scanned); edx = len % 64.
	add	rsi, rax
	and	edx, 63
	# 16..63 tail bytes: hand them to memchr_aligned (same path as a1's long case).
	cmp	edx, 15
	ja	.LBB1_9
# %bb.7:
	# No tail bytes -> false.
	test	rdx, rdx
	je	.LBB1_8
# %bb.10:
	# rdx = index of the last tail byte; rcx = current tail index.
	dec	rdx
	xor	ecx, ecx

.LBB1_11:                               # =>This Inner Loop Header: Depth=1
	# Scalar tail scan against the immediate 10.
	cmp	byte ptr [rsi + rcx], 10
	sete	al
	je	.LBB1_5
# %bb.12:                               #   in Loop: Header=BB1_11 Depth=1
	# lea leaves the flags from "cmp rdx, rcx" intact for jne.
	cmp	rdx, rcx
	lea	rcx, [rcx + 1]
	jne	.LBB1_11

.LBB1_5:
	# al = 1 on a tail match, or 0 from the last non-matching sete.
                                        # kill: def $al killed $al killed $eax
	ret

.LBB1_9:
	# memchr_aligned(needle = 10 in edi, tail ptr = rsi, tail len = rdx).
	# push rax only realigns rsp to 16 bytes for the call.
	push	rax
	mov	edi, 10
	call	qword ptr [rip + core::slice::memchr::memchr_aligned@GOTPCREL]
	# rax appears to carry the Option discriminant (1 = Some) -> al = found.
	cmp	rax, 1
	sete	al
	add	rsp, 8
                                        # kill: def $al killed $al killed $eax
	ret

.LBB1_4:
	# Vector loop found a match -> true.
	mov	al, 1
                                        # kill: def $al killed $al killed $eax
	ret

.LBB1_8:
	# No match in the vector part and no tail bytes -> false.
	xor	eax, eax
                                        # kill: def $al killed $al killed $eax
	ret
                                        # -- End function

A3 64 assembly

# a3 -- returns al = 1 if the needle byte occurs in the slice, else al = 0.
# Registers as used below: rdi = data pointer, rsi = length, dl = needle byte.
# Same structure as a2::<u8, 64>: 64-byte chunks as four 16-byte SSE2 compares
# OR-reduced to one mask; tail of len % 64 bytes goes to a scalar loop (1..15)
# or to core::slice::memchr::memchr_aligned (16..63).
# Differences from a2::<u8, 64> visible in the two listings:
#   1. The needle is a runtime argument (dl). It is broadcast into all 16 lanes
#      of xmm0 with movd/punpcklbw/pshuflw/pshufd, whereas a2 loads a
#      precomputed vector of 10s from the constant pool.
#   2. The tail compares against dl (and passes it to memchr_aligned) instead
#      of the immediate 10.
#   3. Register allocation: the length lives in rcx because edx holds the needle,
#      the loop offset is in rdi, and the mask goes to r8d.
# The broadcast runs at most once per call, so for large inputs the two versions
# are expected to perform essentially the same (NOTE(review): confirm by benchmark).
a3:                                     # @a3
# %bb.0:
	# rcx = len, rsi = ptr (edx stays free to hold the needle).
	mov	rcx, rsi
	mov	rsi, rdi
	# rax = len rounded down to a multiple of 64 (length check, not alignment).
	movabs	rax, 9223372036854775744
	and	rax, rcx
	# Fewer than 64 bytes: skip the vector loop (and the broadcast).
	je	.LBB13_5
# %bb.1:
	# Broadcast the needle byte into all 16 bytes of xmm0 (only done when len >= 64).
	movd	xmm0, edx
	punpcklbw	xmm0, xmm0              # xmm0 = xmm0[0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7]
	pshuflw	xmm0, xmm0, 0                   # xmm0 = xmm0[0,0,0,0,4,5,6,7]
	pshufd	xmm0, xmm0, 68                  # xmm0 = xmm0[0,1,0,1]
	# rdi = byte offset of the current 64-byte chunk.
	xor	edi, edi

.LBB13_3:                               # =>This Inner Loop Header: Depth=1
	# Load the four 16-byte quarters of the chunk.
	movdqu	xmm1, xmmword ptr [rsi + rdi]
	movdqu	xmm2, xmmword ptr [rsi + rdi + 16]
	movdqu	xmm3, xmmword ptr [rsi + rdi + 32]
	movdqu	xmm4, xmmword ptr [rsi + rdi + 48]
	# Compare each quarter with the needle and OR the masks into xmm3.
	pcmpeqb	xmm2, xmm0
	pcmpeqb	xmm4, xmm0
	por	xmm4, xmm2
	pcmpeqb	xmm1, xmm0
	pcmpeqb	xmm3, xmm0
	por	xmm3, xmm1
	por	xmm3, xmm4
	# One mask extraction and one branch per 64 bytes (r8d because edx holds the needle).
	pmovmskb	r8d, xmm3
	test	r8d, r8d
	jne	.LBB13_4
# %bb.2:                                #   in Loop: Header=BB13_3 Depth=1
	add	rdi, 64
	cmp	rax, rdi
	jne	.LBB13_3

.LBB13_5:
	# rsi = start of the tail; ecx = len % 64.
	add	rsi, rax
	and	ecx, 63
	# 16..63 tail bytes: hand them to memchr_aligned.
	cmp	ecx, 15
	ja	.LBB13_8
# %bb.6:
	# No tail bytes -> false.
	test	rcx, rcx
	je	.LBB13_7
# %bb.9:
	# rcx = index of the last tail byte; rdi = current tail index.
	dec	rcx
	xor	edi, edi

.LBB13_10:                              # =>This Inner Loop Header: Depth=1
	# Scalar tail scan against the runtime needle in dl.
	cmp	byte ptr [rsi + rdi], dl
	sete	al
	je	.LBB13_12
# %bb.11:                               #   in Loop: Header=BB13_10 Depth=1
	# lea leaves the flags from "cmp rcx, rdi" intact for jne.
	cmp	rcx, rdi
	lea	rdi, [rdi + 1]
	jne	.LBB13_10

.LBB13_12:
	# al = 1 on a tail match, or 0 from the last non-matching sete.
                                        # kill: def $al killed $al killed $eax
	ret

.LBB13_8:
	# memchr_aligned(needle = edi, tail ptr = rsi, tail len = rdx).
	# push rax only realigns rsp to 16 bytes for the call.
	push	rax
	movzx	edi, dl
	mov	rdx, rcx
	call	qword ptr [rip + core::slice::memchr::memchr_aligned@GOTPCREL]
	# rax appears to carry the Option discriminant (1 = Some) -> al = found.
	cmp	rax, 1
	sete	al
	add	rsp, 8
                                        # kill: def $al killed $al killed $eax
	ret

.LBB13_4:
	# Vector loop found a match -> true.
	mov	al, 1
                                        # kill: def $al killed $al killed $eax
	ret

.LBB13_7:
	# No match in the vector part and no tail bytes -> false.
	xor	eax, eax
                                        # kill: def $al killed $al killed $eax
	ret
                                        # -- End function

Please correct my interpretation if it is wrong, as I relied on an AI's explanation of the assembly code. Also, is there any way to make the function even faster?

Here is the link to the code : Rust Playground

Also, is there any way to make the generic code look cleaner/more elegant?

1 post - 1 participant

Read full topic

🏷️ Rust_feed