From 5764c550ed5b891e71107857822d7143cd510796 Mon Sep 17 00:00:00 2001
From: Frank Denis
Date: Wed, 28 Oct 2020 23:25:34 +0100
Subject: [PATCH] std/crypto: vectorize Salsa20

20% faster on x86_64, slower on aarch64 as usual :/
---
 lib/std/crypto/salsa20.zig | 247 ++++++++++++++++++++++++++++++++++---
 1 file changed, 233 insertions(+), 14 deletions(-)

diff --git a/lib/std/crypto/salsa20.zig b/lib/std/crypto/salsa20.zig
index 2e0c78726b..7d367367fe 100644
--- a/lib/std/crypto/salsa20.zig
+++ b/lib/std/crypto/salsa20.zig
@@ -9,11 +9,202 @@ const crypto = std.crypto;
 const debug = std.debug;
 const math = std.math;
 const mem = std.mem;
+const Vector = std.meta.Vector;
 
 const Poly1305 = crypto.onetimeauth.Poly1305;
 const Blake2b = crypto.hash.blake2.Blake2b;
 const X25519 = crypto.dh.X25519;
 
+/// Salsa20 implementation that keeps the 16-word state in four vectors of
+/// four u32 each. `Salsa20Impl` picks between this and Salsa20NonVecImpl
+/// depending on the target CPU architecture.
+const Salsa20VecImpl = struct {
+    const Lane = Vector(4, u32);
+    const Half = Vector(2, u32);
+    const BlockVec = [4]Lane;
+
+    /// Builds the initial state as four lanes: key[0..4], key[4..8], the
+    /// "expand 32-byte k" constant, and `d`.
+    fn initContext(key: [8]u32, d: [4]u32) BlockVec {
+        const c = "expand 32-byte k";
+        const constant_le = comptime [4]u32{
+            mem.readIntLittle(u32, c[0..4]),
+            mem.readIntLittle(u32, c[4..8]),
+            mem.readIntLittle(u32, c[8..12]),
+            mem.readIntLittle(u32, c[12..16]),
+        };
+        return BlockVec{
+            Lane{ key[0], key[1], key[2], key[3] },
+            Lane{ key[4], key[5], key[6], key[7] },
+            Lane{ constant_le[0], constant_le[1], constant_le[2], constant_le[3] },
+            Lane{ d[0], d[1], d[2], d[3] },
+        };
+    }
+
+    /// Rotates every 32-bit element of `x` left by `n` bits.
+    inline fn rot(x: Lane, comptime n: u5) Lane {
+        return (x << @splat(4, @as(u5, n))) | (x >> @splat(4, @as(u5, 1 +% ~n)));
+    }
+
+    /// Runs the 20 Salsa20 rounds (10 iterations of two rounds) on `input`
+    /// and stores the result in `x`. When `feedback` is true the pre-round
+    /// state is added back with wrapping addition.
+    inline fn salsa20Core(x: *BlockVec, input: BlockVec, comptime feedback: bool) void {
+        const n1n2n3n0 = Lane{ input[3][1], input[3][2], input[3][3], input[3][0] };
+        const n1n2 = Half{ n1n2n3n0[0], n1n2n3n0[1] };
+        const n3n0 = Half{ n1n2n3n0[2], n1n2n3n0[3] };
+        const k0k1 = Half{ input[0][0], input[0][1] };
+        const k2k3 = Half{ input[0][2], input[0][3] };
+        const k4k5 = Half{ input[1][0], input[1][1] };
+        const k6k7 = Half{ input[1][2], input[1][3] };
+        const n0k0 = Half{ n3n0[1], k0k1[0] };
+        const k0n0 = Half{ n0k0[1], n0k0[0] };
+        const k4k5k0n0 = Lane{ k4k5[0], k4k5[1], k0n0[0], k0n0[1] };
+        const k1k6 = Half{ k0k1[1], k6k7[0] };
+        const k6k1 = Half{ k1k6[1], k1k6[0] };
+        const n1n2k6k1 = Lane{ n1n2[0], n1n2[1], k6k1[0], k6k1[1] };
+        const k7n3 = Half{ k6k7[1], n3n0[0] };
+        const n3k7 = Half{ k7n3[1], k7n3[0] };
+        const k2k3n3k7 = Lane{ k2k3[0], k2k3[1], n3k7[0], n3k7[1] };
+
+        var diag0 = input[2];
+        var diag1 = @shuffle(u32, k4k5k0n0, undefined, [_]i32{ 1, 2, 3, 0 });
+        var diag2 = @shuffle(u32, n1n2k6k1, undefined, [_]i32{ 1, 2, 3, 0 });
+        var diag3 = @shuffle(u32, k2k3n3k7, undefined, [_]i32{ 1, 2, 3, 0 });
+
+        const start0 = diag0;
+        const start1 = diag1;
+        const start2 = diag2;
+        const start3 = diag3;
+
+        var i: usize = 0;
+        while (i < 20) : (i += 2) {
+            var a0 = diag1 +% diag0;
+            diag3 ^= rot(a0, 7);
+            var a1 = diag0 +% diag3;
+            diag2 ^= rot(a1, 9);
+            var a2 = diag3 +% diag2;
+            diag1 ^= rot(a2, 13);
+            var a3 = diag2 +% diag1;
+            diag0 ^= rot(a3, 18);
+
+            var diag3_shift = @shuffle(u32, diag3, undefined, [_]i32{ 3, 0, 1, 2 });
+            var diag2_shift = @shuffle(u32, diag2, undefined, [_]i32{ 2, 3, 0, 1 });
+            var diag1_shift = @shuffle(u32, diag1, undefined, [_]i32{ 1, 2, 3, 0 });
+            diag3 = diag3_shift;
+            diag2 = diag2_shift;
+            diag1 = diag1_shift;
+
+            a0 = diag3 +% diag0;
+            diag1 ^= rot(a0, 7);
+            a1 = diag0 +% diag1;
+            diag2 ^= rot(a1, 9);
+            a2 = diag1 +% diag2;
+            diag3 ^= rot(a2, 13);
+            a3 = diag2 +% diag3;
+            diag0 ^= rot(a3, 18);
+
+            diag1_shift = @shuffle(u32, diag1, undefined, [_]i32{ 3, 0, 1, 2 });
+            diag2_shift = @shuffle(u32, diag2, undefined, [_]i32{ 2, 3, 0, 1 });
+            diag3_shift = @shuffle(u32, diag3, undefined, [_]i32{ 1, 2, 3, 0 });
+            diag1 = diag1_shift;
+            diag2 = diag2_shift;
+            diag3 = diag3_shift;
+        }
+
+        if (feedback) {
+            diag0 +%= start0;
+            diag1 +%= start1;
+            diag2 +%= start2;
+            diag3 +%= start3;
+        }
+
+        const x0x1x10x11 = Lane{ diag0[0], diag1[1], diag0[2], diag1[3] };
+        const x12x13x6x7 = Lane{ diag1[0], diag2[1], diag1[2], diag2[3] };
+        const x8x9x2x3 = Lane{ diag2[0], diag3[1], diag2[2], diag3[3] };
+        const x4x5x14x15 = Lane{ diag3[0], diag0[1], diag3[2], diag0[3] };
+
+        x[0] = Lane{ x0x1x10x11[0], x0x1x10x11[1], x8x9x2x3[2], x8x9x2x3[3] };
+        x[1] = Lane{ x4x5x14x15[0], x4x5x14x15[1], x12x13x6x7[2], x12x13x6x7[3] };
+        x[2] = Lane{ x8x9x2x3[0], x8x9x2x3[1], x0x1x10x11[2], x0x1x10x11[3] };
+        x[3] = Lane{ x12x13x6x7[0], x12x13x6x7[1], x4x5x14x15[2], x4x5x14x15[3] };
+    }
+
+    /// Serializes the 16 words of `x` into `out` as little-endian u32 values.
+    fn hashToBytes(out: *[64]u8, x: BlockVec) void {
+        var i: usize = 0;
+        while (i < 4) : (i += 1) {
+            mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]);
+            mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]);
+            mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]);
+            mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]);
+        }
+    }
+
+    /// XORs `in` with the keystream generated from `key` and `d` and writes
+    /// the result to `out`, which must be at least as long as `in`.
+    fn salsa20Xor(out: []u8, in: []const u8, key: [8]u32, d: [4]u32) void {
+        var ctx = initContext(key, d);
+        var x: BlockVec = undefined;
+        var buf: [64]u8 = undefined;
+        var i: usize = 0;
+        while (i + 64 <= in.len) : (i += 64) {
+            salsa20Core(x[0..], ctx, true);
+            hashToBytes(buf[0..], x);
+            var xout = out[i..];
+            const xin = in[i..];
+            var j: usize = 0;
+            while (j < 64) : (j += 1) {
+                xout[j] = xin[j];
+            }
+            j = 0;
+            while (j < 64) : (j += 1) {
+                xout[j] ^= buf[j];
+            }
+            // Advance the 64-bit block counter. It lives in d[2] (low word) and
+            // d[3] (high word), which initContext stores in lane 3; this mirrors
+            // ctx[8]/ctx[9] in Salsa20NonVecImpl. Lane 2 is the constant.
+            ctx[3][2] +%= 1;
+            if (ctx[3][2] == 0) {
+                ctx[3][3] += 1;
+            }
+        }
+        if (i < in.len) {
+            salsa20Core(x[0..], ctx, true);
+            hashToBytes(buf[0..], x);
+
+            var xout = out[i..];
+            const xin = in[i..];
+            var j: usize = 0;
+            while (j < in.len % 64) : (j += 1) {
+                xout[j] = xin[j] ^ buf[j];
+            }
+        }
+    }
+
+    /// HSalsa20: runs the core without feedback over `key` and the 16-byte
+    /// `input`, returning state words 0, 5, 10, 15, 6, 7, 8 and 9.
+    fn hsalsa20(input: [16]u8, key: [32]u8) [32]u8 {
+        var c: [4]u32 = undefined;
+        for (c) |_, i| {
+            c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]);
+        }
+        const ctx = initContext(keyToWords(key), c);
+        var x: BlockVec = undefined;
+        salsa20Core(x[0..], ctx, false);
+        var out: [32]u8 = undefined;
+        mem.writeIntLittle(u32, out[0..4], x[0][0]);
+        mem.writeIntLittle(u32, out[4..8], x[1][1]);
+        mem.writeIntLittle(u32, out[8..12], x[2][2]);
+        mem.writeIntLittle(u32, out[12..16], x[3][3]);
+        mem.writeIntLittle(u32, out[16..20], x[1][2]);
+        mem.writeIntLittle(u32, out[20..24], x[1][3]);
+        mem.writeIntLittle(u32, out[24..28], x[2][0]);
+        mem.writeIntLittle(u32, out[28..32], x[2][1]);
+        return out;
+    }
+};
+
 const Salsa20NonVecImpl = struct {
     const BlockVec = [16]u32;
 
@@ -49,7 +240,7 @@ const Salsa20NonVecImpl = struct {
         };
     }
 
-    inline fn salsa20Core(x: *BlockVec, input: BlockVec) void {
+    inline fn salsa20Core(x: *BlockVec, input: BlockVec, comptime feedback: bool) void {
         const arx_steps = comptime [_]QuarterRound{
             Rp(4, 0, 12, 7), Rp(8, 4, 0, 9), Rp(12, 8, 4, 13), Rp(0, 12, 8, 18),
             Rp(9, 5, 1, 7), Rp(13, 9, 5, 9), Rp(1, 13, 9, 13), Rp(5, 1, 13, 18),
@@ -67,6 +258,12 @@ const Salsa20NonVecImpl = struct {
                 x[r.a] ^= math.rotl(u32, x[r.b] +% x[r.c], r.d);
             }
         }
+        if (feedback) {
+            j = 0;
+            while (j < 16) : (j += 1) {
+                x[j] +%= input[j];
+            }
+        }
     }
 
     fn hashToBytes(out: *[64]u8, x: BlockVec) void {
@@ -75,21 +272,13 @@ const Salsa20NonVecImpl = struct {
         }
     }
 
-    fn contextFeedback(x: *BlockVec, ctx: BlockVec) void {
-        var i: usize = 0;
-        while (i < 16) : (i += 1) {
-            x[i] +%= ctx[i];
-        }
-    }
-
     fn salsa20Xor(out: []u8, in: []const u8, key: [8]u32, d: [4]u32) void {
         var ctx = initContext(key, d);
         var x: BlockVec = undefined;
         var buf: [64]u8 = undefined;
         var i: usize = 0;
         while (i + 64 <= in.len) : (i += 64) {
-            salsa20Core(x[0..], ctx);
-            contextFeedback(&x, ctx);
+            salsa20Core(x[0..], ctx, true);
             hashToBytes(buf[0..], x);
             var xout = out[i..];
             const xin = in[i..];
@@ -104,8 +293,7 @@ const Salsa20NonVecImpl = struct {
             ctx[9] += @boolToInt(@addWithOverflow(u32, ctx[8], 1, &ctx[8]));
         }
         if (i < in.len) {
-            salsa20Core(x[0..], ctx);
-            contextFeedback(&x, ctx);
+            salsa20Core(x[0..], ctx, true);
             hashToBytes(buf[0..], x);
 
             var xout = out[i..];
@@ -124,7 +312,7 @@ const Salsa20NonVecImpl = struct {
         }
         const ctx = initContext(keyToWords(key), c);
         var x: BlockVec = undefined;
-        salsa20Core(x[0..], ctx);
+        salsa20Core(x[0..], ctx, false);
         var out: [32]u8 = undefined;
         mem.writeIntLittle(u32, out[0..4], x[0]);
         mem.writeIntLittle(u32, out[4..8], x[5]);
@@ -138,7 +326,7 @@ const Salsa20NonVecImpl = struct {
     }
 };
 
-const Salsa20Impl = Salsa20NonVecImpl;
+const Salsa20Impl = if (std.Target.current.cpu.arch == .x86_64) Salsa20VecImpl else Salsa20NonVecImpl;
 
 fn keyToWords(key: [32]u8) [8]u32 {
     var k: [8]u32 = undefined;
@@ -381,6 +569,37 @@ pub const SealedBox = struct {
     }
 };
 
+const htest = @import("test.zig");
+
+test "(x)salsa20" {
+    const key = [_]u8{0x69} ** 32;
+    const nonce = [_]u8{0x42} ** 8;
+    const msg = [_]u8{0} ** 20;
+    var c: [msg.len]u8 = undefined;
+
+    Salsa20.xor(&c, msg[0..], 0, key, nonce);
+    htest.assertEqual("30ff9933aa6534ff5207142593cd1fca4b23bdd8", c[0..]);
+
+    const extended_nonce = [_]u8{0x42} ** 24;
+    XSalsa20.xor(&c, msg[0..], 0, key, extended_nonce);
+    htest.assertEqual("b4ab7d82e750ec07644fa3281bce6cd91d4243f9", c[0..]);
+}
+
+test "salsa20 multi-block counter" {
+    const key = [_]u8{0x69} ** 32;
+    const nonce = [_]u8{0x42} ** 8;
+    const msg = [_]u8{0} ** 192;
+    var whole: [msg.len]u8 = undefined;
+    var tail: [msg.len - 64]u8 = undefined;
+
+    // A stream started at block 1 must equal a stream started at block 0
+    // with its first 64-byte block dropped (assumes the third argument of
+    // `Salsa20.xor` is the initial block counter).
+    Salsa20.xor(&whole, msg[0..], 0, key, nonce);
+    Salsa20.xor(&tail, msg[64..], 1, key, nonce);
+    std.testing.expectEqualSlices(u8, whole[64..], tail[0..]);
+}
+
 test "xsalsa20poly1305" {
     var msg: [100]u8 = undefined;
     var msg2: [msg.len]u8 = undefined;