From 0d192ee9ef6a69ca4841b1932579b9178938a6d5 Mon Sep 17 00:00:00 2001 From: Frank Denis <124872+jedisct1@users.noreply.github.com> Date: Tue, 1 Nov 2022 18:49:13 +0100 Subject: [PATCH] std.crypto.onetimeauth.Ghash: make GHASH 2 - 2.5x faster (#13374) Rewrite GHASH to use 128-bit multiplication over non-reversed integers, and up to 8 blocks aggregated reduction. lib/std/crypto/benchmark.zig results: Xeon E5: Before: 1604 MiB/s After: 4005 MiB/s Apple M1: Before: 2769 MiB/s After: 6014 MiB/s This also makes AES-GCM faster by the way. --- lib/std/crypto/aes_gcm.zig | 7 +- lib/std/crypto/ghash.zig | 354 ++++++++++++++++++------------------- 2 files changed, 182 insertions(+), 179 deletions(-) diff --git a/lib/std/crypto/aes_gcm.zig b/lib/std/crypto/aes_gcm.zig index bb35bc9e8a..2a363b86eb 100644 --- a/lib/std/crypto/aes_gcm.zig +++ b/lib/std/crypto/aes_gcm.zig @@ -3,6 +3,7 @@ const assert = std.debug.assert; const crypto = std.crypto; const debug = std.debug; const Ghash = std.crypto.onetimeauth.Ghash; +const math = std.math; const mem = std.mem; const modes = crypto.core.modes; const AuthenticationError = crypto.errors.AuthenticationError; @@ -34,7 +35,8 @@ fn AesGcm(comptime Aes: anytype) type { mem.writeIntBig(u32, j[nonce_length..][0..4], 1); aes.encrypt(&t, &j); - var mac = Ghash.init(&h); + const block_count = (math.divCeil(usize, ad.len, Ghash.block_length) catch unreachable) + (math.divCeil(usize, c.len, Ghash.block_length) catch unreachable); + var mac = Ghash.initForBlockCount(&h, block_count); mac.update(ad); mac.pad(); @@ -66,7 +68,8 @@ fn AesGcm(comptime Aes: anytype) type { mem.writeIntBig(u32, j[nonce_length..][0..4], 1); aes.encrypt(&t, &j); - var mac = Ghash.init(&h); + const block_count = (math.divCeil(usize, ad.len, Ghash.block_length) catch unreachable) + (math.divCeil(usize, c.len, Ghash.block_length) catch unreachable) + 1; + var mac = Ghash.initForBlockCount(&h, block_count); mac.update(ad); mac.pad(); diff --git a/lib/std/crypto/ghash.zig b/lib/std/crypto/ghash.zig index 8f57f9033a..bddd4d0f8d 100644 --- a/lib/std/crypto/ghash.zig +++ b/lib/std/crypto/ghash.zig @@ -1,6 +1,3 @@ -// -// Adapted from BearSSL's ctmul64 implementation originally written by Thomas Pornin - const std = @import("../std.zig"); const builtin = @import("builtin"); const assert = std.debug.assert; @@ -8,6 +5,8 @@ const math = std.math; const mem = std.mem; const utils = std.crypto.utils; +const Precomp = u128; + /// GHASH is a universal hash function that features multiplication /// by a fixed parameter within a Galois field. /// @@ -19,116 +18,132 @@ pub const Ghash = struct { pub const mac_length = 16; pub const key_length = 16; - y0: u64 = 0, - y1: u64 = 0, - h0: u64, - h1: u64, - h2: u64, - h0r: u64, - h1r: u64, - h2r: u64, + const pc_count = if (builtin.mode != .ReleaseSmall) 8 else 1; - hh0: u64 = undefined, - hh1: u64 = undefined, - hh2: u64 = undefined, - hh0r: u64 = undefined, - hh1r: u64 = undefined, - hh2r: u64 = undefined, + hx: [pc_count]Precomp, + acc: u128 = 0, leftover: usize = 0, buf: [block_length]u8 align(16) = undefined, - pub fn init(key: *const [key_length]u8) Ghash { - const h1 = mem.readIntBig(u64, key[0..8]); - const h0 = mem.readIntBig(u64, key[8..16]); - const h1r = @bitReverse(h1); - const h0r = @bitReverse(h0); - const h2 = h0 ^ h1; - const h2r = h0r ^ h1r; + /// Initialize the GHASH state with a key, and a minimum number of block count. 
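The motivation for this entry point is visible in the aes_gcm.zig hunk above: the caller already knows how much data it is going to authenticate, so it can tell GHASH how many powers of H are worth precomputing. A condensed sketch of that caller-side computation, reusing the `ad`, `c` and `h` locals from the decrypt hunk:

    // Upper bound on the number of 16-byte GHASH blocks this MAC will absorb.
    const block_count = (math.divCeil(usize, ad.len, Ghash.block_length) catch unreachable) +
        (math.divCeil(usize, c.len, Ghash.block_length) catch unreachable) + 1;
    var mac = Ghash.initForBlockCount(&h, block_count);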
+ pub fn initForBlockCount(key: *const [key_length]u8, block_count: usize) Ghash { + const h0 = mem.readIntBig(u128, key[0..16]); - if (builtin.mode == .ReleaseSmall) { - return Ghash{ - .h0 = h0, - .h1 = h1, - .h2 = h2, - .h0r = h0r, - .h1r = h1r, - .h2r = h2r, - }; - } else { - // Precompute H^2 - var hh = Ghash{ - .h0 = h0, - .h1 = h1, - .h2 = h2, - .h0r = h0r, - .h1r = h1r, - .h2r = h2r, - }; - hh.update(key); - const hh1 = hh.y1; - const hh0 = hh.y0; - const hh1r = @bitReverse(hh1); - const hh0r = @bitReverse(hh0); - const hh2 = hh0 ^ hh1; - const hh2r = hh0r ^ hh1r; + // We keep the values encoded as in GCM, not Polyval, i.e. without reversing the bits. + // This is fine, but the reversed result would be shifted by 1 bit. So, we shift h + // to compensate. + const carry = ((@as(u128, 0xc2) << 120) | 1) & (@as(u128, 0) -% (h0 >> 127)); + const h = (h0 << 1) ^ carry; - return Ghash{ - .h0 = h0, - .h1 = h1, - .h2 = h2, - .h0r = h0r, - .h1r = h1r, - .h2r = h2r, - - .hh0 = hh0, - .hh1 = hh1, - .hh2 = hh2, - .hh0r = hh0r, - .hh1r = hh1r, - .hh2r = hh2r, - }; + var hx: [pc_count]Precomp = undefined; + hx[0] = h; + if (builtin.mode != .ReleaseSmall) { + if (block_count > 2) { + hx[1] = gcm_reduce(clsq128(hx[0])); // h^2 + } + if (block_count > 4) { + hx[2] = gcm_reduce(clmul128(hx[1], h)); // h^3 + hx[3] = gcm_reduce(clsq128(hx[1])); // h^4 + } + if (block_count > 8) { + hx[4] = gcm_reduce(clmul128(hx[3], h)); // h^5 + hx[5] = gcm_reduce(clmul128(hx[4], h)); // h^6 + hx[6] = gcm_reduce(clmul128(hx[5], h)); // h^7 + hx[7] = gcm_reduce(clsq128(hx[3])); // h^8 + } } + return Ghash{ .hx = hx }; } - inline fn clmul_pclmul(x: u64, y: u64) u64 { + /// Initialize the GHASH state with a key. + pub fn init(key: *const [key_length]u8) Ghash { + return Ghash.initForBlockCount(key, math.maxInt(usize)); + } + + // Carryless multiplication of two 64-bit integers for x86_64. + inline fn clmul_pclmul(x: u64, y: u64) u128 { const product = asm ( \\ vpclmulqdq $0x00, %[x], %[y], %[out] : [out] "=x" (-> @Vector(2, u64)), : [x] "x" (@bitCast(@Vector(2, u64), @as(u128, x))), [y] "x" (@bitCast(@Vector(2, u64), @as(u128, y))), ); - return product[0]; + return (@as(u128, product[1]) << 64) | product[0]; } - inline fn clmul_pmull(x: u64, y: u64) u64 { + // Carryless multiplication of two 64-bit integers for ARM crypto. + inline fn clmul_pmull(x: u64, y: u64) u128 { const product = asm ( \\ pmull %[out].1q, %[x].1d, %[y].1d : [out] "=w" (-> @Vector(2, u64)), : [x] "w" (@bitCast(@Vector(2, u64), @as(u128, x))), [y] "w" (@bitCast(@Vector(2, u64), @as(u128, y))), ); - return product[0]; + return (@as(u128, product[1]) << 64) | product[0]; } - fn clmul_soft(x: u64, y: u64) u64 { - const x0 = x & 0x1111111111111111; - const x1 = x & 0x2222222222222222; - const x2 = x & 0x4444444444444444; - const x3 = x & 0x8888888888888888; + // Software carryless multiplication of two 64-bit integers. 
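As a mental model for what the hardware intrinsics above and the software fallback below all compute, this is the schoolbook definition of a 64x64 carryless multiply. It is not part of the patch and, unlike the masked routine that follows, it branches on its inputs, so it is only suitable as a reference or test oracle:

    fn clmul_reference(x: u64, y: u64) u128 {
        var r: u128 = 0;
        var i: usize = 0;
        while (i < 64) : (i += 1) {
            // XOR in a shifted copy of x for every set bit of y; no carries propagate.
            if ((y >> @intCast(u6, i)) & 1 != 0) {
                r ^= @as(u128, x) << @intCast(u7, i);
            }
        }
        return r;
    }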
+ fn clmul_soft(x: u64, y: u64) u128 { + const x0 = x & 0x1111111111111110; + const x1 = x & 0x2222222222222220; + const x2 = x & 0x4444444444444440; + const x3 = x & 0x8888888888888880; const y0 = y & 0x1111111111111111; const y1 = y & 0x2222222222222222; const y2 = y & 0x4444444444444444; const y3 = y & 0x8888888888888888; - var z0 = (x0 *% y0) ^ (x1 *% y3) ^ (x2 *% y2) ^ (x3 *% y1); - var z1 = (x0 *% y1) ^ (x1 *% y0) ^ (x2 *% y3) ^ (x3 *% y2); - var z2 = (x0 *% y2) ^ (x1 *% y1) ^ (x2 *% y0) ^ (x3 *% y3); - var z3 = (x0 *% y3) ^ (x1 *% y2) ^ (x2 *% y1) ^ (x3 *% y0); - z0 &= 0x1111111111111111; - z1 &= 0x2222222222222222; - z2 &= 0x4444444444444444; - z3 &= 0x8888888888888888; - return z0 | z1 | z2 | z3; + const z0 = (x0 * @as(u128, y0)) ^ (x1 * @as(u128, y3)) ^ (x2 * @as(u128, y2)) ^ (x3 * @as(u128, y1)); + const z1 = (x0 * @as(u128, y1)) ^ (x1 * @as(u128, y0)) ^ (x2 * @as(u128, y3)) ^ (x3 * @as(u128, y2)); + const z2 = (x0 * @as(u128, y2)) ^ (x1 * @as(u128, y1)) ^ (x2 * @as(u128, y0)) ^ (x3 * @as(u128, y3)); + const z3 = (x0 * @as(u128, y3)) ^ (x1 * @as(u128, y2)) ^ (x2 * @as(u128, y1)) ^ (x3 * @as(u128, y0)); + + const x0_mask = @as(u64, 0) -% (x & 1); + const x1_mask = @as(u64, 0) -% ((x >> 1) & 1); + const x2_mask = @as(u64, 0) -% ((x >> 2) & 1); + const x3_mask = @as(u64, 0) -% ((x >> 3) & 1); + const extra = (x0_mask & y) ^ (@as(u128, x1_mask & y) << 1) ^ + (@as(u128, x2_mask & y) << 2) ^ (@as(u128, x3_mask & y) << 3); + + return (z0 & 0x11111111111111111111111111111111) ^ + (z1 & 0x22222222222222222222222222222222) ^ + (z2 & 0x44444444444444444444444444444444) ^ + (z3 & 0x88888888888888888888888888888888) ^ extra; + } + + // Square a 128-bit integer in GF(2^128). + fn clsq128(x: u128) u256 { + const lo = @truncate(u64, x); + const hi = @truncate(u64, x >> 64); + const mid = lo ^ hi; + const r_lo = clmul(lo, lo); + const r_hi = clmul(hi, hi); + const r_mid = clmul(mid, mid) ^ r_lo ^ r_hi; + return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo; + } + + // Multiply two 128-bit integers in GF(2^128). + inline fn clmul128(x: u128, y: u128) u256 { + const x_lo = @truncate(u64, x); + const x_hi = @truncate(u64, x >> 64); + const y_lo = @truncate(u64, y); + const y_hi = @truncate(u64, y >> 64); + const r_lo = clmul(x_lo, y_lo); + const r_hi = clmul(x_hi, y_hi); + const r_mid = clmul(x_lo ^ x_hi, y_lo ^ y_hi) ^ r_lo ^ r_hi; + return (@as(u256, r_hi) << 128) ^ (@as(u256, r_mid) << 64) ^ r_lo; + } + + // Reduce a 256-bit representative of a polynomial modulo the irreducible polynomial x^128 + x^127 + x^126 + x^121 + 1. + // This is done *without reversing the bits*, using Shay Gueron's black magic demysticated here: + // https://blog.quarkslab.com/reversing-a-finite-field-multiplication-optimization.html + inline fn gcm_reduce(x: u256) u128 { + const p64 = (((1 << 121) | (1 << 126) | (1 << 127)) >> 64); + const a = clmul(@truncate(u64, x), p64); + const b = ((@truncate(u128, x) << 64) | (@truncate(u128, x) >> 64)) ^ a; + const c = clmul(@truncate(u64, b), p64); + const d = ((b << 64) | (b >> 64)) ^ c; + return d ^ @truncate(u128, x >> 128); } const has_pclmul = std.Target.x86.featureSetHas(builtin.cpu.features, .pclmul); @@ -142,116 +157,100 @@ pub const Ghash = struct { break :impl clmul_soft; }; + // Process a block of 16 bytes. 
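The aggregated loops in blocks() below rest on two facts: gcm_reduce is linear over XOR, and two consecutive Horner steps ((acc ^ b0) * H ^ b1) * H equal (acc ^ b0) * H^2 ^ b1 * H, so several unreduced 256-bit products can be XORed together and reduced once. The 4- and 8-block loops extend this to (acc ^ b0) * H^8 ^ b1 * H^7 ^ ... ^ b7 * H. A sketch of the two-block case in terms of the helpers above (these wrapper functions are illustrative, not part of the patch):

    // Full multiplication in GF(2^128): carryless multiply, then reduce.
    fn gf_mul(x: u128, y: u128) u128 {
        return gcm_reduce(clmul128(x, y));
    }

    // Two Horner steps, reducing after each block.
    fn twoBlocksSequential(acc: u128, b0: u128, b1: u128, h: u128) u128 {
        return gf_mul(gf_mul(acc ^ b0, h) ^ b1, h);
    }

    // Same result with a single reduction, given h2 = gf_mul(h, h); this is what
    // the 2-block loop below computes using st.hx[1] and st.hx[0].
    fn twoBlocksAggregated(acc: u128, b0: u128, b1: u128, h: u128, h2: u128) u128 {
        return gcm_reduce(clmul128(acc ^ b0, h2) ^ clmul128(b1, h));
    }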
fn blocks(st: *Ghash, msg: []const u8) void { assert(msg.len % 16 == 0); // GHASH blocks() expects full blocks - var y1 = st.y1; - var y0 = st.y0; + var acc = st.acc; var i: usize = 0; - // 2-blocks aggregated reduction if (builtin.mode != .ReleaseSmall) { + // 8-blocks aggregated reduction + while (i + 128 <= msg.len) : (i += 128) { + const b0 = mem.readIntBig(u128, msg[i..][0..16]); + const z0 = acc ^ b0; + const z0h = clmul128(z0, st.hx[7]); + + const b1 = mem.readIntBig(u128, msg[i..][16..32]); + const b1h = clmul128(b1, st.hx[6]); + + const b2 = mem.readIntBig(u128, msg[i..][32..48]); + const b2h = clmul128(b2, st.hx[5]); + + const b3 = mem.readIntBig(u128, msg[i..][48..64]); + const b3h = clmul128(b3, st.hx[4]); + + const b4 = mem.readIntBig(u128, msg[i..][64..80]); + const b4h = clmul128(b4, st.hx[3]); + + const b5 = mem.readIntBig(u128, msg[i..][80..96]); + const b5h = clmul128(b5, st.hx[2]); + + const b6 = mem.readIntBig(u128, msg[i..][96..112]); + const b6h = clmul128(b6, st.hx[1]); + + const b7 = mem.readIntBig(u128, msg[i..][112..128]); + const b7h = clmul128(b7, st.hx[0]); + + const u = z0h ^ b1h ^ b2h ^ b3h ^ b4h ^ b5h ^ b6h ^ b7h; + acc = gcm_reduce(u); + } + + // 4-blocks aggregated reduction + while (i + 64 <= msg.len) : (i += 64) { + // (acc + b0) * H^4 unreduced + const b0 = mem.readIntBig(u128, msg[i..][0..16]); + const z0 = acc ^ b0; + const z0h = clmul128(z0, st.hx[3]); + + // b1 * H^3 unreduced + const b1 = mem.readIntBig(u128, msg[i..][16..32]); + const b1h = clmul128(b1, st.hx[2]); + + // b2 * H^2 unreduced + const b2 = mem.readIntBig(u128, msg[i..][32..48]); + const b2h = clmul128(b2, st.hx[1]); + + // b3 * H unreduced + const b3 = mem.readIntBig(u128, msg[i..][48..64]); + const b3h = clmul128(b3, st.hx[0]); + + // (((acc + b0) * H^4) + B1 * H^3 + B2 * H^2 + B3 * H) (mod P) + const u = z0h ^ b1h ^ b2h ^ b3h; + acc = gcm_reduce(u); + } + + // 2-blocks aggregated reduction while (i + 32 <= msg.len) : (i += 32) { - // B0 * H^2 unreduced - y1 ^= mem.readIntBig(u64, msg[i..][0..8]); - y0 ^= mem.readIntBig(u64, msg[i..][8..16]); + // (acc + b0) * H^2 unreduced + const b0 = mem.readIntBig(u128, msg[i..][0..16]); + const z0 = acc ^ b0; + const z0h = clmul128(z0, st.hx[1]); - const y1r = @bitReverse(y1); - const y0r = @bitReverse(y0); - const y2 = y0 ^ y1; - const y2r = y0r ^ y1r; + // b1 * H unreduced + const b1 = mem.readIntBig(u128, msg[i..][16..32]); + const b1h = clmul128(b1, st.hx[0]); - var z0 = clmul(y0, st.hh0); - var z1 = clmul(y1, st.hh1); - var z2 = clmul(y2, st.hh2) ^ z0 ^ z1; - var z0h = clmul(y0r, st.hh0r); - var z1h = clmul(y1r, st.hh1r); - var z2h = clmul(y2r, st.hh2r) ^ z0h ^ z1h; - - // B1 * H unreduced - const sy1 = mem.readIntBig(u64, msg[i..][16..24]); - const sy0 = mem.readIntBig(u64, msg[i..][24..32]); - - const sy1r = @bitReverse(sy1); - const sy0r = @bitReverse(sy0); - const sy2 = sy0 ^ sy1; - const sy2r = sy0r ^ sy1r; - - const sz0 = clmul(sy0, st.h0); - const sz1 = clmul(sy1, st.h1); - const sz2 = clmul(sy2, st.h2) ^ sz0 ^ sz1; - const sz0h = clmul(sy0r, st.h0r); - const sz1h = clmul(sy1r, st.h1r); - const sz2h = clmul(sy2r, st.h2r) ^ sz0h ^ sz1h; - - // ((B0 * H^2) + B1 * H) (mod M) - z0 ^= sz0; - z1 ^= sz1; - z2 ^= sz2; - z0h ^= sz0h; - z1h ^= sz1h; - z2h ^= sz2h; - z0h = @bitReverse(z0h) >> 1; - z1h = @bitReverse(z1h) >> 1; - z2h = @bitReverse(z2h) >> 1; - - var v3 = z1h; - var v2 = z1 ^ z2h; - var v1 = z0h ^ z2; - var v0 = z0; - - v3 = (v3 << 1) | (v2 >> 63); - v2 = (v2 << 1) | (v1 >> 63); - v1 = (v1 << 1) | (v0 >> 63); - v0 = (v0 << 1); - - 
v2 ^= v0 ^ (v0 >> 1) ^ (v0 >> 2) ^ (v0 >> 7); - v1 ^= (v0 << 63) ^ (v0 << 62) ^ (v0 << 57); - y1 = v3 ^ v1 ^ (v1 >> 1) ^ (v1 >> 2) ^ (v1 >> 7); - y0 = v2 ^ (v1 << 63) ^ (v1 << 62) ^ (v1 << 57); + // (((acc + b0) * H^2) + B1 * H) (mod P) + const u = z0h ^ b1h; + acc = gcm_reduce(u); } } // single block while (i + 16 <= msg.len) : (i += 16) { - y1 ^= mem.readIntBig(u64, msg[i..][0..8]); - y0 ^= mem.readIntBig(u64, msg[i..][8..16]); + // (acc + b0) * H unreduced + const b0 = mem.readIntBig(u128, msg[i..][0..16]); + const z0 = acc ^ b0; + const z0h = clmul128(z0, st.hx[0]); - const y1r = @bitReverse(y1); - const y0r = @bitReverse(y0); - const y2 = y0 ^ y1; - const y2r = y0r ^ y1r; - - const z0 = clmul(y0, st.h0); - const z1 = clmul(y1, st.h1); - var z2 = clmul(y2, st.h2) ^ z0 ^ z1; - var z0h = clmul(y0r, st.h0r); - var z1h = clmul(y1r, st.h1r); - var z2h = clmul(y2r, st.h2r) ^ z0h ^ z1h; - z0h = @bitReverse(z0h) >> 1; - z1h = @bitReverse(z1h) >> 1; - z2h = @bitReverse(z2h) >> 1; - - // shift & reduce - var v3 = z1h; - var v2 = z1 ^ z2h; - var v1 = z0h ^ z2; - var v0 = z0; - - v3 = (v3 << 1) | (v2 >> 63); - v2 = (v2 << 1) | (v1 >> 63); - v1 = (v1 << 1) | (v0 >> 63); - v0 = (v0 << 1); - - v2 ^= v0 ^ (v0 >> 1) ^ (v0 >> 2) ^ (v0 >> 7); - v1 ^= (v0 << 63) ^ (v0 << 62) ^ (v0 << 57); - y1 = v3 ^ v1 ^ (v1 >> 1) ^ (v1 >> 2) ^ (v1 >> 7); - y0 = v2 ^ (v1 << 63) ^ (v1 << 62) ^ (v1 << 57); + // (acc + b0) * H (mod P) + acc = gcm_reduce(z0h); } - st.y1 = y1; - st.y0 = y0; + st.acc = acc; } + /// Absorb a message into the GHASH state. pub fn update(st: *Ghash, m: []const u8) void { var mb = m; @@ -295,14 +294,15 @@ pub const Ghash = struct { st.leftover = 0; } + /// Compute the GHASH of the entire input. pub fn final(st: *Ghash, out: *[mac_length]u8) void { st.pad(); - mem.writeIntBig(u64, out[0..8], st.y1); - mem.writeIntBig(u64, out[8..16], st.y0); + mem.writeIntBig(u128, out[0..16], st.acc); utils.secureZero(u8, @ptrCast([*]u8, st)[0..@sizeOf(Ghash)]); } + /// Compute the GHASH of a message. pub fn create(out: *[mac_length]u8, msg: []const u8, key: *const [key_length]u8) void { var st = Ghash.init(key); st.update(msg);
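The public API is otherwise unchanged apart from the new initForBlockCount entry point. A hypothetical test (not part of this patch) showing the one-shot and streaming paths producing the same tag:

    test "Ghash one-shot and streaming produce the same tag" {
        const key = [_]u8{0x42} ** Ghash.key_length;
        const msg = "a message spanning more than one 16-byte block";

        // One-shot.
        var tag: [Ghash.mac_length]u8 = undefined;
        Ghash.create(&tag, msg, &key);

        // Streaming, sizing the precomputation table with the new API.
        var st = Ghash.initForBlockCount(&key, msg.len / Ghash.block_length + 1);
        st.update(msg);
        var tag2: [Ghash.mac_length]u8 = undefined;
        st.final(&tag2);

        try std.testing.expectEqualSlices(u8, &tag, &tag2);
    }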