From c2a779ae79facc0c6a102825723ae707fdbf8c19 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Thu, 31 Oct 2024 20:55:34 -0400 Subject: [PATCH 01/14] std.crypto.tls: implement TLSv1.2 --- lib/std/crypto/25519/ed25519.zig | 36 +- lib/std/crypto/Certificate.zig | 369 ++++--- lib/std/crypto/ecdsa.zig | 50 +- lib/std/crypto/tls.zig | 300 ++++-- lib/std/crypto/tls/Client.zig | 1565 +++++++++++++++++++----------- lib/std/http/protocol.zig | 24 +- 6 files changed, 1533 insertions(+), 811 deletions(-) diff --git a/lib/std/crypto/25519/ed25519.zig b/lib/std/crypto/25519/ed25519.zig index d7b51271d2..3620cfc4ba 100644 --- a/lib/std/crypto/25519/ed25519.zig +++ b/lib/std/crypto/25519/ed25519.zig @@ -151,7 +151,9 @@ pub const Ed25519 = struct { a: Curve, expected_r: Curve, - fn init(sig: Signature, public_key: PublicKey) (NonCanonicalError || EncodingError || IdentityElementError)!Verifier { + pub const InitError = NonCanonicalError || EncodingError || IdentityElementError; + + fn init(sig: Signature, public_key: PublicKey) InitError!Verifier { const r = sig.r; const s = sig.s; try Curve.scalar.rejectNonCanonical(s); @@ -173,8 +175,11 @@ pub const Ed25519 = struct { self.h.update(msg); } + pub const VerifyError = WeakPublicKeyError || IdentityElementError || + SignatureVerificationError; + /// Verify that the signature is valid for the entire message. - pub fn verify(self: *Verifier) (SignatureVerificationError || WeakPublicKeyError || IdentityElementError)!void { + pub fn verify(self: *Verifier) VerifyError!void { var hram64: [Sha512.digest_length]u8 = undefined; self.h.final(&hram64); const hram = Curve.scalar.reduce64(hram64); @@ -197,10 +202,10 @@ pub const Ed25519 = struct { s: CompressedScalar, /// Return the raw signature (r, s) in little-endian format. - pub fn toBytes(self: Signature) [encoded_length]u8 { + pub fn toBytes(sig: Signature) [encoded_length]u8 { var bytes: [encoded_length]u8 = undefined; - bytes[0..Curve.encoded_length].* = self.r; - bytes[Curve.encoded_length..].* = self.s; + bytes[0..Curve.encoded_length].* = sig.r; + bytes[Curve.encoded_length..].* = sig.s; return bytes; } @@ -214,17 +219,26 @@ pub const Ed25519 = struct { } /// Create a Verifier for incremental verification of a signature. - pub fn verifier(self: Signature, public_key: PublicKey) (NonCanonicalError || EncodingError || IdentityElementError)!Verifier { - return Verifier.init(self, public_key); + pub fn verifier(sig: Signature, public_key: PublicKey) Verifier.InitError!Verifier { + return Verifier.init(sig, public_key); } + pub const VerifyError = Verifier.InitError || Verifier.VerifyError; + /// Verify the signature against a message and public key. /// Return IdentityElement or NonCanonical if the public key or signature are not in the expected range, /// or SignatureVerificationError if the signature is invalid for the given message and key. - pub fn verify(self: Signature, msg: []const u8, public_key: PublicKey) (IdentityElementError || NonCanonicalError || SignatureVerificationError || EncodingError || WeakPublicKeyError)!void { - var st = try Verifier.init(self, public_key); - st.update(msg); - return st.verify(); + pub fn verify(sig: Signature, msg: []const u8, public_key: PublicKey) VerifyError!void { + try sig.concatVerify(&.{msg}, public_key); + } + + /// Verify the signature against a concatenated message and public key. + /// Return IdentityElement or NonCanonical if the public key or signature are not in the expected range, + /// or SignatureVerificationError if the signature is invalid for the given message and key. + pub fn concatVerify(sig: Signature, msg: []const []const u8, public_key: PublicKey) VerifyError!void { + var st = try Verifier.init(sig, public_key); + for (msg) |part| st.update(part); + try st.verify(); } }; diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 3580d11fcd..9ca6aa5e48 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -20,18 +20,18 @@ pub const Algorithm = enum { curveEd25519, pub const map = std.StaticStringMap(Algorithm).initComptime(.{ - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x01 }, .ecdsa_with_SHA224 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02 }, .ecdsa_with_SHA256 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03 }, .ecdsa_with_SHA384 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04 }, .ecdsa_with_SHA512 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x02 }, .md2WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x04 }, .md5WithRSAEncryption }, - .{ &[_]u8{ 0x2B, 0x65, 0x70 }, .curveEd25519 }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x01 }, .ecdsa_with_SHA224 }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02 }, .ecdsa_with_SHA256 }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03 }, .ecdsa_with_SHA384 }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04 }, .ecdsa_with_SHA512 }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x02 }, .md2WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x04 }, .md5WithRSAEncryption }, + .{ &.{ 0x2B, 0x65, 0x70 }, .curveEd25519 }, }); pub fn Hash(comptime algorithm: Algorithm) type { @@ -49,13 +49,15 @@ pub const Algorithm = enum { pub const AlgorithmCategory = enum { rsaEncryption, + rsassa_pss, X9_62_id_ecPublicKey, curveEd25519, pub const map = std.StaticStringMap(AlgorithmCategory).initComptime(.{ - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey }, - .{ &[_]u8{ 0x2B, 0x65, 0x70 }, .curveEd25519 }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0A }, .rsassa_pss }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey }, + .{ &.{ 0x2B, 0x65, 0x70 }, .curveEd25519 }, }); }; @@ -74,18 +76,18 @@ pub const Attribute = enum { domainComponent, pub const map = std.StaticStringMap(Attribute).initComptime(.{ - .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, - .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber }, - .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, - .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, - .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, - .{ &[_]u8{ 0x55, 0x04, 0x09 }, .streetAddress }, - .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, - .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, - .{ &[_]u8{ 0x55, 0x04, 0x11 }, .postalCode }, - .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x09, 0x01 }, .pkcs9_emailAddress }, - .{ &[_]u8{ 0x09, 0x92, 0x26, 0x89, 0x93, 0xF2, 0x2C, 0x64, 0x01, 0x19 }, .domainComponent }, + .{ &.{ 0x55, 0x04, 0x03 }, .commonName }, + .{ &.{ 0x55, 0x04, 0x05 }, .serialNumber }, + .{ &.{ 0x55, 0x04, 0x06 }, .countryName }, + .{ &.{ 0x55, 0x04, 0x07 }, .localityName }, + .{ &.{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, + .{ &.{ 0x55, 0x04, 0x09 }, .streetAddress }, + .{ &.{ 0x55, 0x04, 0x0A }, .organizationName }, + .{ &.{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, + .{ &.{ 0x55, 0x04, 0x11 }, .postalCode }, + .{ &.{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x09, 0x01 }, .pkcs9_emailAddress }, + .{ &.{ 0x09, 0x92, 0x26, 0x89, 0x93, 0xF2, 0x2C, 0x64, 0x01, 0x19 }, .domainComponent }, }); }; @@ -95,9 +97,9 @@ pub const NamedCurve = enum { X9_62_prime256v1, pub const map = std.StaticStringMap(NamedCurve).initComptime(.{ - .{ &[_]u8{ 0x2B, 0x81, 0x04, 0x00, 0x22 }, .secp384r1 }, - .{ &[_]u8{ 0x2B, 0x81, 0x04, 0x00, 0x23 }, .secp521r1 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07 }, .X9_62_prime256v1 }, + .{ &.{ 0x2B, 0x81, 0x04, 0x00, 0x22 }, .secp384r1 }, + .{ &.{ 0x2B, 0x81, 0x04, 0x00, 0x23 }, .secp521r1 }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07 }, .X9_62_prime256v1 }, }); pub fn Curve(comptime curve: NamedCurve) type { @@ -131,28 +133,28 @@ pub const ExtensionId = enum { netscape_comment, pub const map = std.StaticStringMap(ExtensionId).initComptime(.{ - .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, - .{ &[_]u8{ 0x55, 0x1D, 0x01 }, .authority_key_identifier }, - .{ &[_]u8{ 0x55, 0x1D, 0x07 }, .subject_alt_name }, - .{ &[_]u8{ 0x55, 0x1D, 0x0E }, .subject_key_identifier }, - .{ &[_]u8{ 0x55, 0x1D, 0x0F }, .key_usage }, - .{ &[_]u8{ 0x55, 0x1D, 0x0A }, .basic_constraints }, - .{ &[_]u8{ 0x55, 0x1D, 0x10 }, .private_key_usage_period }, - .{ &[_]u8{ 0x55, 0x1D, 0x11 }, .subject_alt_name }, - .{ &[_]u8{ 0x55, 0x1D, 0x12 }, .issuer_alt_name }, - .{ &[_]u8{ 0x55, 0x1D, 0x13 }, .basic_constraints }, - .{ &[_]u8{ 0x55, 0x1D, 0x14 }, .crl_number }, - .{ &[_]u8{ 0x55, 0x1D, 0x1F }, .crl_distribution_points }, - .{ &[_]u8{ 0x55, 0x1D, 0x20 }, .certificate_policies }, - .{ &[_]u8{ 0x55, 0x1D, 0x23 }, .authority_key_identifier }, - .{ &[_]u8{ 0x55, 0x1D, 0x25 }, .ext_key_usage }, - .{ &[_]u8{ 0x2B, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x01 }, .msCertsrvCAVersion }, - .{ &[_]u8{ 0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x01 }, .info_access }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF6, 0x7D, 0x07, 0x41, 0x00 }, .entrustVersInfo }, - .{ &[_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x14, 0x02 }, .enroll_certtype }, - .{ &[_]u8{ 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x0c }, .pe_logotype }, - .{ &[_]u8{ 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x01 }, .netscape_cert_type }, - .{ &[_]u8{ 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x0d }, .netscape_comment }, + .{ &.{ 0x55, 0x04, 0x03 }, .commonName }, + .{ &.{ 0x55, 0x1D, 0x01 }, .authority_key_identifier }, + .{ &.{ 0x55, 0x1D, 0x07 }, .subject_alt_name }, + .{ &.{ 0x55, 0x1D, 0x0E }, .subject_key_identifier }, + .{ &.{ 0x55, 0x1D, 0x0F }, .key_usage }, + .{ &.{ 0x55, 0x1D, 0x0A }, .basic_constraints }, + .{ &.{ 0x55, 0x1D, 0x10 }, .private_key_usage_period }, + .{ &.{ 0x55, 0x1D, 0x11 }, .subject_alt_name }, + .{ &.{ 0x55, 0x1D, 0x12 }, .issuer_alt_name }, + .{ &.{ 0x55, 0x1D, 0x13 }, .basic_constraints }, + .{ &.{ 0x55, 0x1D, 0x14 }, .crl_number }, + .{ &.{ 0x55, 0x1D, 0x1F }, .crl_distribution_points }, + .{ &.{ 0x55, 0x1D, 0x20 }, .certificate_policies }, + .{ &.{ 0x55, 0x1D, 0x23 }, .authority_key_identifier }, + .{ &.{ 0x55, 0x1D, 0x25 }, .ext_key_usage }, + .{ &.{ 0x2B, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x01 }, .msCertsrvCAVersion }, + .{ &.{ 0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x01 }, .info_access }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF6, 0x7D, 0x07, 0x41, 0x00 }, .entrustVersInfo }, + .{ &.{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x14, 0x02 }, .enroll_certtype }, + .{ &.{ 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x0c }, .pe_logotype }, + .{ &.{ 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x01 }, .netscape_cert_type }, + .{ &.{ 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x0d }, .netscape_comment }, }); }; @@ -185,6 +187,7 @@ pub const Parsed = struct { pub const PubKeyAlgo = union(AlgorithmCategory) { rsaEncryption: void, + rsassa_pss: void, X9_62_id_ecPublicKey: NamedCurve, curveEd25519: void, }; @@ -386,7 +389,7 @@ test "Parsed.checkHostName" { try expectEqual(true, Parsed.checkHostName("bar.ziglang.org", "*.Ziglang.ORG")); } -pub const ParseError = der.Element.ParseElementError || ParseVersionError || ParseTimeError || ParseEnumError || ParseBitStringError; +pub const ParseError = der.Element.ParseError || ParseVersionError || ParseTimeError || ParseEnumError || ParseBitStringError; pub fn parse(cert: Certificate) ParseError!Parsed { const cert_bytes = cert.buffer; @@ -413,13 +416,9 @@ pub fn parse(cert: Certificate) ParseError!Parsed { const pub_key_info = try der.Element.parse(cert_bytes, subject.slice.end); const pub_key_signature_algorithm = try der.Element.parse(cert_bytes, pub_key_info.slice.start); const pub_key_algo_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.start); - const pub_key_algo_tag = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem); - var pub_key_algo: Parsed.PubKeyAlgo = undefined; - switch (pub_key_algo_tag) { - .rsaEncryption => { - pub_key_algo = .{ .rsaEncryption = {} }; - }, - .X9_62_id_ecPublicKey => { + const pub_key_algo: Parsed.PubKeyAlgo = switch (try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem)) { + inline else => |tag| @unionInit(Parsed.PubKeyAlgo, @tagName(tag), {}), + .X9_62_id_ecPublicKey => pub_key_algo: { // RFC 5480 Section 2.1.1.1 Named Curve // ECParameters ::= CHOICE { // namedCurve OBJECT IDENTIFIER @@ -428,12 +427,9 @@ pub fn parse(cert: Certificate) ParseError!Parsed { // } const params_elem = try der.Element.parse(cert_bytes, pub_key_algo_elem.slice.end); const named_curve = try parseNamedCurve(cert_bytes, params_elem); - pub_key_algo = .{ .X9_62_id_ecPublicKey = named_curve }; + break :pub_key_algo .{ .X9_62_id_ecPublicKey = named_curve }; }, - .curveEd25519 => { - pub_key_algo = .{ .curveEd25519 = {} }; - }, - } + }; const pub_key_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.end); const pub_key = try parseBitString(cert, pub_key_elem); @@ -731,7 +727,7 @@ pub fn parseVersion(bytes: []const u8, version_elem: der.Element) ParseVersionEr fn verifyRsa( comptime Hash: type, - message: []const u8, + msg: []const u8, sig: []const u8, pub_key_algo: Parsed.PubKeyAlgo, pub_key: []const u8, @@ -743,59 +739,14 @@ fn verifyRsa( if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; - const hash_der = switch (Hash) { - crypto.hash.Sha1 => [_]u8{ - 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, - 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, - }, - crypto.hash.sha2.Sha224 => [_]u8{ - 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, - 0x00, 0x04, 0x1c, - }, - crypto.hash.sha2.Sha256 => [_]u8{ - 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, - 0x00, 0x04, 0x20, - }, - crypto.hash.sha2.Sha384 => [_]u8{ - 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, - 0x00, 0x04, 0x30, - }, - crypto.hash.sha2.Sha512 => [_]u8{ - 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, - 0x00, 0x04, 0x40, - }, - else => @compileError("unreachable"), - }; - - var msg_hashed: [Hash.digest_length]u8 = undefined; - Hash.hash(message, &msg_hashed, .{}); - switch (modulus.len) { inline 128, 256, 384, 512 => |modulus_len| { - const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; - const em: [modulus_len]u8 = - [2]u8{ 0, 1 } ++ - ([1]u8{0xff} ** ps_len) ++ - [1]u8{0} ++ - hash_der ++ - msg_hashed; - - const public_key = rsa.PublicKey.fromBytes(exponent, modulus) catch return error.CertificateSignatureInvalid; - const em_dec = rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key) catch |err| switch (err) { - error.MessageTooLong => unreachable, - }; - - if (!mem.eql(u8, &em, &em_dec)) { + const public_key = rsa.PublicKey.fromBytes(exponent, modulus) catch + return error.CertificateSignatureInvalid; + rsa.PKCS1v1_5Signature.verify(modulus_len, sig[0..modulus_len].*, msg, public_key, Hash) catch return error.CertificateSignatureInvalid; - } - }, - else => { - return error.CertificateSignatureUnsupportedBitCount; }, + else => return error.CertificateSignatureUnsupportedBitCount, } } @@ -908,9 +859,9 @@ pub const der = struct { pub const empty: Slice = .{ .start = 0, .end = 0 }; }; - pub const ParseElementError = error{CertificateFieldHasInvalidLength}; + pub const ParseError = error{CertificateFieldHasInvalidLength}; - pub fn parse(bytes: []const u8, index: u32) ParseElementError!Element { + pub fn parse(bytes: []const u8, index: u32) Element.ParseError!Element { var i = index; const identifier = @as(Identifier, @bitCast(bytes[i])); i += 1; @@ -958,21 +909,41 @@ pub const rsa = struct { const Modulus = std.crypto.ff.Modulus(max_modulus_bits); const Fe = Modulus.Fe; + /// RFC 3447 8.1 RSASSA-PSS pub const PSSSignature = struct { pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 { - var result = [1]u8{0} ** modulus_len; - std.mem.copyForwards(u8, &result, msg); + var result: [modulus_len]u8 = undefined; + @memcpy(result[0..msg.len], msg); + @memset(result[msg.len..], 0); return result; } - pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type) !void { + pub const VerifyError = EncryptError || error{InvalidSignature}; + + pub fn verify( + comptime modulus_len: usize, + sig: [modulus_len]u8, + msg: []const u8, + public_key: PublicKey, + comptime Hash: type, + ) VerifyError!void { + try concatVerify(modulus_len, sig, &.{msg}, public_key, Hash); + } + + pub fn concatVerify( + comptime modulus_len: usize, + sig: [modulus_len]u8, + msg: []const []const u8, + public_key: PublicKey, + comptime Hash: type, + ) VerifyError!void { const mod_bits = public_key.n.bits(); const em_dec = try encrypt(modulus_len, sig, public_key); - EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash) catch unreachable; + try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash); } - fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type) !void { + fn EMSA_PSS_VERIFY(msg: []const []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type) VerifyError!void { // 1. If the length of M is greater than the input limitation for // the hash function (2^61 - 1 octets for SHA-1), output // "inconsistent" and stop. @@ -986,7 +957,11 @@ pub const rsa = struct { // 2. Let mHash = Hash(M), an octet string of length hLen. var mHash: [Hash.digest_length]u8 = undefined; - Hash.hash(msg, &mHash, .{}); + { + var hasher: Hash = .init(.{}); + for (msg) |part| hasher.update(part); + hasher.final(&mHash); + } // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. if (emLen < Hash.digest_length + sLen + 2) { @@ -1082,25 +1057,14 @@ pub const rsa = struct { } fn MGF1(comptime Hash: type, out: []u8, seed: *const [Hash.digest_length]u8, len: usize) ![]u8 { - var counter: usize = 0; + var counter: u32 = 0; var idx: usize = 0; - var c: [4]u8 = undefined; - var hash: [Hash.digest_length + c.len]u8 = undefined; - @memcpy(hash[0..Hash.digest_length], seed); - var hashed: [Hash.digest_length]u8 = undefined; + var hash = seed.* ++ @as([4]u8, undefined); while (idx < len) { - c[0] = @as(u8, @intCast((counter >> 24) & 0xFF)); - c[1] = @as(u8, @intCast((counter >> 16) & 0xFF)); - c[2] = @as(u8, @intCast((counter >> 8) & 0xFF)); - c[3] = @as(u8, @intCast(counter & 0xFF)); - - std.mem.copyForwards(u8, hash[seed.len..], &c); - Hash.hash(&hash, &hashed, .{}); - - std.mem.copyForwards(u8, out[idx..], &hashed); - idx += hashed.len; - + std.mem.writeInt(u32, hash[seed.len..][0..4], counter, .big); + Hash.hash(&hash, out[idx..][0..Hash.digest_length], .{}); + idx += Hash.digest_length; counter += 1; } @@ -1108,11 +1072,128 @@ pub const rsa = struct { } }; + /// RFC 3447 8.2 RSASSA-PKCS1-v1_5 + pub const PKCS1v1_5Signature = struct { + pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 { + var result: [modulus_len]u8 = undefined; + @memcpy(result[0..msg.len], msg); + @memset(result[msg.len..], 0); + return result; + } + + pub const VerifyError = EncryptError || error{InvalidSignature}; + + pub fn verify( + comptime modulus_len: usize, + sig: [modulus_len]u8, + msg: []const u8, + public_key: PublicKey, + comptime Hash: type, + ) VerifyError!void { + try concatVerify(modulus_len, sig, &.{msg}, public_key, Hash); + } + + pub fn concatVerify( + comptime modulus_len: usize, + sig: [modulus_len]u8, + msg: []const []const u8, + public_key: PublicKey, + comptime Hash: type, + ) VerifyError!void { + const em_dec = try encrypt(modulus_len, sig, public_key); + const em = try EMSA_PKCS1_V1_5_ENCODE(msg, modulus_len, Hash); + if (!std.mem.eql(u8, &em_dec, &em)) return error.InvalidSignature; + } + + fn EMSA_PKCS1_V1_5_ENCODE(msg: []const []const u8, comptime emLen: usize, comptime Hash: type) VerifyError![emLen]u8 { + comptime var em_index = emLen; + var em: [emLen]u8 = undefined; + + // 1. Apply the hash function to the message M to produce a hash value + // H: + // + // H = Hash(M). + // + // If the hash function outputs "message too long," output "message + // too long" and stop. + var hasher: Hash = .init(.{}); + for (msg) |part| hasher.update(part); + em_index -= Hash.digest_length; + hasher.final(em[em_index..]); + + // 2. Encode the algorithm ID for the hash function and the hash value + // into an ASN.1 value of type DigestInfo (see Appendix A.2.4) with + // the Distinguished Encoding Rules (DER), where the type DigestInfo + // has the syntax + // + // DigestInfo ::= SEQUENCE { + // digestAlgorithm AlgorithmIdentifier, + // digest OCTET STRING + // } + // + // The first field identifies the hash function and the second + // contains the hash value. Let T be the DER encoding of the + // DigestInfo value (see the notes below) and let tLen be the length + // in octets of T. + const hash_der: []const u8 = &switch (Hash) { + crypto.hash.Sha1 => .{ + 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, + 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, + }, + crypto.hash.sha2.Sha224 => .{ + 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, + 0x00, 0x04, 0x1c, + }, + crypto.hash.sha2.Sha256 => .{ + 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, + 0x00, 0x04, 0x20, + }, + crypto.hash.sha2.Sha384 => .{ + 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, + 0x00, 0x04, 0x30, + }, + crypto.hash.sha2.Sha512 => .{ + 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, + 0x00, 0x04, 0x40, + }, + else => @compileError("unreachable"), + }; + em_index -= hash_der.len; + @memcpy(em[em_index..][0..hash_der.len], hash_der); + + // 3. If emLen < tLen + 11, output "intended encoded message length too + // short" and stop. + + // 4. Generate an octet string PS consisting of emLen - tLen - 3 octets + // with hexadecimal value 0xff. The length of PS will be at least 8 + // octets. + em_index -= 1; + @memset(em[2..em_index], 0xff); + + // 5. Concatenate PS, the DER encoding T, and other padding to form the + // encoded message EM as + // + // EM = 0x00 || 0x01 || PS || 0x00 || T. + em[em_index] = 0x00; + em[1] = 0x01; + em[0] = 0x00; + + // 6. Output EM. + return em; + } + }; + pub const PublicKey = struct { n: Modulus, e: Fe, - pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8) !PublicKey { + pub const FromBytesError = error{CertificatePublicKeyInvalid}; + + pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8) FromBytesError!PublicKey { // Reject modulus below 512 bits. // 512-bit RSA was factored in 1999, so this limit barely means anything, // but establish some limit now to ratchet in what we can. @@ -1137,7 +1218,9 @@ pub const rsa = struct { }; } - pub fn parseDer(pub_key: []const u8) !struct { modulus: []const u8, exponent: []const u8 } { + pub const ParseDerError = der.Element.ParseError || error{CertificateFieldHasWrongDataType}; + + pub fn parseDer(pub_key: []const u8) ParseDerError!struct { modulus: []const u8, exponent: []const u8 } { const pub_key_seq = try der.Element.parse(pub_key, 0); if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start); @@ -1156,7 +1239,9 @@ pub const rsa = struct { } }; - fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey) ![modulus_len]u8 { + const EncryptError = error{MessageTooLong}; + + fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey) EncryptError![modulus_len]u8 { const m = Fe.fromBytes(public_key.n, &msg, .big) catch return error.MessageTooLong; const e = public_key.n.powPublic(m, public_key.e) catch unreachable; var res: [modulus_len]u8 = undefined; diff --git a/lib/std/crypto/ecdsa.zig b/lib/std/crypto/ecdsa.zig index 50f13a010d..a015178f3d 100644 --- a/lib/std/crypto/ecdsa.zig +++ b/lib/std/crypto/ecdsa.zig @@ -91,24 +91,33 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { s: Curve.scalar.CompressedScalar, /// Create a Verifier for incremental verification of a signature. - pub fn verifier(self: Signature, public_key: PublicKey) (NonCanonicalError || EncodingError || IdentityElementError)!Verifier { - return Verifier.init(self, public_key); + pub fn verifier(sig: Signature, public_key: PublicKey) Verifier.InitError!Verifier { + return Verifier.init(sig, public_key); } + pub const VerifyError = Verifier.InitError || Verifier.VerifyError; + /// Verify the signature against a message and public key. /// Return IdentityElement or NonCanonical if the public key or signature are not in the expected range, /// or SignatureVerificationError if the signature is invalid for the given message and key. - pub fn verify(self: Signature, msg: []const u8, public_key: PublicKey) (IdentityElementError || NonCanonicalError || SignatureVerificationError)!void { - var st = try Verifier.init(self, public_key); - st.update(msg); - return st.verify(); + pub fn verify(sig: Signature, msg: []const u8, public_key: PublicKey) VerifyError!void { + try sig.concatVerify(&.{msg}, public_key); + } + + /// Verify the signature against a concatenated message and public key. + /// Return IdentityElement or NonCanonical if the public key or signature are not in the expected range, + /// or SignatureVerificationError if the signature is invalid for the given message and key. + pub fn concatVerify(sig: Signature, msg: []const []const u8, public_key: PublicKey) VerifyError!void { + var st = try Verifier.init(sig, public_key); + for (msg) |part| st.update(part); + try st.verify(); } /// Return the raw signature (r, s) in big-endian format. - pub fn toBytes(self: Signature) [encoded_length]u8 { + pub fn toBytes(sig: Signature) [encoded_length]u8 { var bytes: [encoded_length]u8 = undefined; - @memcpy(bytes[0 .. encoded_length / 2], &self.r); - @memcpy(bytes[encoded_length / 2 ..], &self.s); + @memcpy(bytes[0 .. encoded_length / 2], &sig.r); + @memcpy(bytes[encoded_length / 2 ..], &sig.s); return bytes; } @@ -124,23 +133,23 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { /// Encode the signature using the DER format. /// The maximum length of the DER encoding is der_encoded_length_max. /// The function returns a slice, that can be shorter than der_encoded_length_max. - pub fn toDer(self: Signature, buf: *[der_encoded_length_max]u8) []u8 { + pub fn toDer(sig: Signature, buf: *[der_encoded_length_max]u8) []u8 { var fb = io.fixedBufferStream(buf); const w = fb.writer(); - const r_len = @as(u8, @intCast(self.r.len + (self.r[0] >> 7))); - const s_len = @as(u8, @intCast(self.s.len + (self.s[0] >> 7))); + const r_len = @as(u8, @intCast(sig.r.len + (sig.r[0] >> 7))); + const s_len = @as(u8, @intCast(sig.s.len + (sig.s[0] >> 7))); const seq_len = @as(u8, @intCast(2 + r_len + 2 + s_len)); w.writeAll(&[_]u8{ 0x30, seq_len }) catch unreachable; w.writeAll(&[_]u8{ 0x02, r_len }) catch unreachable; - if (self.r[0] >> 7 != 0) { + if (sig.r[0] >> 7 != 0) { w.writeByte(0x00) catch unreachable; } - w.writeAll(&self.r) catch unreachable; + w.writeAll(&sig.r) catch unreachable; w.writeAll(&[_]u8{ 0x02, s_len }) catch unreachable; - if (self.s[0] >> 7 != 0) { + if (sig.s[0] >> 7 != 0) { w.writeByte(0x00) catch unreachable; } - w.writeAll(&self.s) catch unreachable; + w.writeAll(&sig.s) catch unreachable; return fb.getWritten(); } @@ -236,7 +245,9 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { s: Curve.scalar.Scalar, public_key: PublicKey, - fn init(sig: Signature, public_key: PublicKey) (IdentityElementError || NonCanonicalError)!Verifier { + pub const InitError = IdentityElementError || NonCanonicalError; + + fn init(sig: Signature, public_key: PublicKey) InitError!Verifier { const r = try Curve.scalar.Scalar.fromBytes(sig.r, .big); const s = try Curve.scalar.Scalar.fromBytes(sig.s, .big); if (r.isZero() or s.isZero()) return error.IdentityElement; @@ -254,8 +265,11 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { self.h.update(data); } + pub const VerifyError = IdentityElementError || NonCanonicalError || + SignatureVerificationError; + /// Verify that the signature is valid for the entire message. - pub fn verify(self: *Verifier) (IdentityElementError || NonCanonicalError || SignatureVerificationError)!void { + pub fn verify(self: *Verifier) VerifyError!void { const ht = Curve.scalar.encoded_length; const h_len = @max(Hash.digest_length, ht); var h: [h_len]u8 = [_]u8{0} ** h_len; diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index fb1b550e42..7732f3b74e 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -54,6 +54,8 @@ pub const close_notify_alert = [_]u8{ }; pub const ProtocolVersion = enum(u16) { + tls_1_0 = 0x0301, + tls_1_1 = 0x0302, tls_1_2 = 0x0303, tls_1_3 = 0x0304, _, @@ -69,14 +71,18 @@ pub const ContentType = enum(u8) { }; pub const HandshakeType = enum(u8) { + hello_request = 0, client_hello = 1, server_hello = 2, new_session_ticket = 4, end_of_early_data = 5, encrypted_extensions = 8, certificate = 11, + server_key_exchange = 12, certificate_request = 13, + server_hello_done = 14, certificate_verify = 15, + client_key_exchange = 16, finished = 20, key_update = 24, message_hash = 254, @@ -198,36 +204,36 @@ pub const AlertDescription = enum(u8) { _, pub fn toError(alert: AlertDescription) Error!void { - return switch (alert) { + switch (alert) { .close_notify => {}, // not an error - .unexpected_message => error.TlsAlertUnexpectedMessage, - .bad_record_mac => error.TlsAlertBadRecordMac, - .record_overflow => error.TlsAlertRecordOverflow, - .handshake_failure => error.TlsAlertHandshakeFailure, - .bad_certificate => error.TlsAlertBadCertificate, - .unsupported_certificate => error.TlsAlertUnsupportedCertificate, - .certificate_revoked => error.TlsAlertCertificateRevoked, - .certificate_expired => error.TlsAlertCertificateExpired, - .certificate_unknown => error.TlsAlertCertificateUnknown, - .illegal_parameter => error.TlsAlertIllegalParameter, - .unknown_ca => error.TlsAlertUnknownCa, - .access_denied => error.TlsAlertAccessDenied, - .decode_error => error.TlsAlertDecodeError, - .decrypt_error => error.TlsAlertDecryptError, - .protocol_version => error.TlsAlertProtocolVersion, - .insufficient_security => error.TlsAlertInsufficientSecurity, - .internal_error => error.TlsAlertInternalError, - .inappropriate_fallback => error.TlsAlertInappropriateFallback, + .unexpected_message => return error.TlsAlertUnexpectedMessage, + .bad_record_mac => return error.TlsAlertBadRecordMac, + .record_overflow => return error.TlsAlertRecordOverflow, + .handshake_failure => return error.TlsAlertHandshakeFailure, + .bad_certificate => return error.TlsAlertBadCertificate, + .unsupported_certificate => return error.TlsAlertUnsupportedCertificate, + .certificate_revoked => return error.TlsAlertCertificateRevoked, + .certificate_expired => return error.TlsAlertCertificateExpired, + .certificate_unknown => return error.TlsAlertCertificateUnknown, + .illegal_parameter => return error.TlsAlertIllegalParameter, + .unknown_ca => return error.TlsAlertUnknownCa, + .access_denied => return error.TlsAlertAccessDenied, + .decode_error => return error.TlsAlertDecodeError, + .decrypt_error => return error.TlsAlertDecryptError, + .protocol_version => return error.TlsAlertProtocolVersion, + .insufficient_security => return error.TlsAlertInsufficientSecurity, + .internal_error => return error.TlsAlertInternalError, + .inappropriate_fallback => return error.TlsAlertInappropriateFallback, .user_canceled => {}, // not an error - .missing_extension => error.TlsAlertMissingExtension, - .unsupported_extension => error.TlsAlertUnsupportedExtension, - .unrecognized_name => error.TlsAlertUnrecognizedName, - .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, - .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, - .certificate_required => error.TlsAlertCertificateRequired, - .no_application_protocol => error.TlsAlertNoApplicationProtocol, - _ => error.TlsAlertUnknown, - }; + .missing_extension => return error.TlsAlertMissingExtension, + .unsupported_extension => return error.TlsAlertUnsupportedExtension, + .unrecognized_name => return error.TlsAlertUnrecognizedName, + .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse, + .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity, + .certificate_required => return error.TlsAlertCertificateRequired, + .no_application_protocol => return error.TlsAlertNoApplicationProtocol, + _ => return error.TlsAlertUnknown, + } } }; @@ -286,6 +292,20 @@ pub const NamedGroup = enum(u16) { }; pub const CipherSuite = enum(u16) { + RSA_WITH_AES_128_CBC_SHA = 0x002F, + DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033, + RSA_WITH_AES_256_CBC_SHA = 0x0035, + DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039, + RSA_WITH_AES_128_CBC_SHA256 = 0x003C, + RSA_WITH_AES_256_CBC_SHA256 = 0x003D, + DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067, + DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B, + RSA_WITH_AES_128_GCM_SHA256 = 0x009C, + RSA_WITH_AES_256_GCM_SHA384 = 0x009D, + DHE_RSA_WITH_AES_128_GCM_SHA256 = 0x009E, + DHE_RSA_WITH_AES_256_GCM_SHA384 = 0x009F, + EMPTY_RENEGOTIATION_INFO_SCSV = 0x00FF, + AES_128_GCM_SHA256 = 0x1301, AES_256_GCM_SHA384 = 0x1302, CHACHA20_POLY1305_SHA256 = 0x1303, @@ -293,7 +313,98 @@ pub const CipherSuite = enum(u16) { AES_128_CCM_8_SHA256 = 0x1305, AEGIS_256_SHA512 = 0x1306, AEGIS_128L_SHA256 = 0x1307, + + ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009, + ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A, + ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xC013, + ECDHE_RSA_WITH_AES_256_CBC_SHA = 0xC014, + ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xC023, + ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xC024, + ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027, + ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028, + ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B, + ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xC02C, + ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F, + ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xC030, + + ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA8, + ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA9, + DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCAA, + _, + + pub const With = enum { + AES_128_CBC_SHA, + AES_256_CBC_SHA, + AES_128_CBC_SHA256, + AES_256_CBC_SHA256, + AES_256_CBC_SHA384, + + AES_128_GCM_SHA256, + AES_256_GCM_SHA384, + + CHACHA20_POLY1305_SHA256, + + AES_128_CCM_SHA256, + AES_128_CCM_8_SHA256, + + AEGIS_256_SHA512, + AEGIS_128L_SHA256, + }; + + pub fn with(cipher_suite: CipherSuite) With { + return switch (cipher_suite) { + .RSA_WITH_AES_128_CBC_SHA, + .DHE_RSA_WITH_AES_128_CBC_SHA, + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + .ECDHE_RSA_WITH_AES_128_CBC_SHA, + => .AES_128_CBC_SHA, + .RSA_WITH_AES_256_CBC_SHA, + .DHE_RSA_WITH_AES_256_CBC_SHA, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + .ECDHE_RSA_WITH_AES_256_CBC_SHA, + => .AES_256_CBC_SHA, + .RSA_WITH_AES_128_CBC_SHA256, + .DHE_RSA_WITH_AES_128_CBC_SHA256, + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + .ECDHE_RSA_WITH_AES_128_CBC_SHA256, + => .AES_128_CBC_SHA256, + .RSA_WITH_AES_256_CBC_SHA256, + .DHE_RSA_WITH_AES_256_CBC_SHA256, + => .AES_256_CBC_SHA256, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + .ECDHE_RSA_WITH_AES_256_CBC_SHA384, + => .AES_256_CBC_SHA384, + + .RSA_WITH_AES_128_GCM_SHA256, + .DHE_RSA_WITH_AES_128_GCM_SHA256, + .AES_128_GCM_SHA256, + .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + => .AES_128_GCM_SHA256, + .RSA_WITH_AES_256_GCM_SHA384, + .DHE_RSA_WITH_AES_256_GCM_SHA384, + .AES_256_GCM_SHA384, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + => .AES_256_GCM_SHA384, + + .CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + .DHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => .CHACHA20_POLY1305_SHA256, + + .AES_128_CCM_SHA256 => .AES_128_CCM_SHA256, + .AES_128_CCM_8_SHA256 => .AES_128_CCM_8_SHA256, + + .AEGIS_256_SHA512 => .AEGIS_256_SHA512, + .AEGIS_128L_SHA256 => .AEGIS_128L_SHA256, + + .EMPTY_RENEGOTIATION_INFO_SCSV => unreachable, + _ => unreachable, + }; + } }; pub const CertificateType = enum(u8) { @@ -308,58 +419,108 @@ pub const KeyUpdateRequest = enum(u8) { _, }; -pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type) type { +pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type, comptime explicit_iv_length: comptime_int) type { return struct { - pub const AEAD = AeadType; - pub const Hash = HashType; - pub const Hmac = crypto.auth.hmac.Hmac(Hash); - pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); + pub const A = ApplicationCipherT(AeadType, HashType, explicit_iv_length); - handshake_secret: [Hkdf.prk_length]u8, - master_secret: [Hkdf.prk_length]u8, - client_handshake_key: [AEAD.key_length]u8, - server_handshake_key: [AEAD.key_length]u8, - client_finished_key: [Hmac.key_length]u8, - server_finished_key: [Hmac.key_length]u8, - client_handshake_iv: [AEAD.nonce_length]u8, - server_handshake_iv: [AEAD.nonce_length]u8, - transcript_hash: Hash, + transcript_hash: A.Hash, + version: union { + tls_1_2: struct { + server_verify_data: [12]u8, + app_cipher: A.Tls_1_2, + }, + tls_1_3: struct { + handshake_secret: [A.Hkdf.prk_length]u8, + master_secret: [A.Hkdf.prk_length]u8, + client_handshake_key: [A.AEAD.key_length]u8, + server_handshake_key: [A.AEAD.key_length]u8, + client_finished_key: [A.Hmac.key_length]u8, + server_finished_key: [A.Hmac.key_length]u8, + client_handshake_iv: [A.AEAD.nonce_length]u8, + server_handshake_iv: [A.AEAD.nonce_length]u8, + }, + }, }; } pub const HandshakeCipher = union(enum) { - AES_128_GCM_SHA256: HandshakeCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), - AES_256_GCM_SHA384: HandshakeCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), - CHACHA20_POLY1305_SHA256: HandshakeCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), - AEGIS_256_SHA512: HandshakeCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512), - AEGIS_128L_SHA256: HandshakeCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), + AES_128_GCM_SHA256: HandshakeCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256, 8), + AES_256_GCM_SHA384: HandshakeCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384, 8), + CHACHA20_POLY1305_SHA256: HandshakeCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256, 0), + AEGIS_256_SHA512: HandshakeCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512, 0), + AEGIS_128L_SHA256: HandshakeCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256, 0), }; -pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type { - return struct { +pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type, comptime explicit_iv_length: comptime_int) type { + return union { pub const AEAD = AeadType; pub const Hash = HashType; pub const Hmac = crypto.auth.hmac.Hmac(Hash); pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - client_secret: [Hash.digest_length]u8, - server_secret: [Hash.digest_length]u8, - client_key: [AEAD.key_length]u8, - server_key: [AEAD.key_length]u8, - client_iv: [AEAD.nonce_length]u8, - server_iv: [AEAD.nonce_length]u8, + pub const enc_key_length = AEAD.key_length; + pub const fixed_iv_length = AEAD.nonce_length - explicit_iv_length; + pub const record_iv_length = explicit_iv_length; + pub const mac_length = AEAD.tag_length; + pub const mac_key_length = Hmac.key_length_min; + + tls_1_2: Tls_1_2, + tls_1_3: Tls_1_3, + + pub const Tls_1_2 = extern struct { + client_write_MAC_key: [mac_key_length]u8, + server_write_MAC_key: [mac_key_length]u8, + client_write_key: [enc_key_length]u8, + server_write_key: [enc_key_length]u8, + client_write_IV: [fixed_iv_length]u8, + server_write_IV: [fixed_iv_length]u8, + // non-standard entropy + client_salt: [record_iv_length]u8, + }; + + pub const Tls_1_3 = struct { + client_secret: [Hash.digest_length]u8, + server_secret: [Hash.digest_length]u8, + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + }; }; } /// Encryption parameters for application traffic. pub const ApplicationCipher = union(enum) { - AES_128_GCM_SHA256: ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), - AES_256_GCM_SHA384: ApplicationCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), - CHACHA20_POLY1305_SHA256: ApplicationCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), - AEGIS_256_SHA512: ApplicationCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512), - AEGIS_128L_SHA256: ApplicationCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), + AES_128_GCM_SHA256: ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256, 8), + AES_256_GCM_SHA384: ApplicationCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384, 8), + CHACHA20_POLY1305_SHA256: ApplicationCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256, 0), + AEGIS_256_SHA512: ApplicationCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512, 0), + AEGIS_128L_SHA256: ApplicationCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256, 0), }; +pub fn hmacExpandLabel( + comptime Hmac: type, + secret: []const u8, + label_then_seed: []const []const u8, + comptime len: usize, +) [len]u8 { + const initial_hmac: Hmac = .init(secret); + var a: [Hmac.mac_length]u8 = undefined; + var result: [std.mem.alignForwardAnyAlign(usize, len, Hmac.mac_length)]u8 = undefined; + var index: usize = 0; + while (index < result.len) : (index += Hmac.mac_length) { + var a_hmac = initial_hmac; + if (index > 0) a_hmac.update(&a) else for (label_then_seed) |part| a_hmac.update(part); + a_hmac.final(&a); + + var result_hmac = initial_hmac; + result_hmac.update(&a); + for (label_then_seed) |part| result_hmac.update(part); + result_hmac.final(result[index..][0..Hmac.mac_length]); + } + return result[0..len].*; +} + pub fn hkdfExpandLabel( comptime Hkdf: type, key: [Hkdf.prk_length]u8, @@ -418,19 +579,16 @@ pub inline fn enum_array(comptime E: type, comptime tags: []const E) [2 + @sizeO return array(2, result); } -pub inline fn int2(x: u16) [2]u8 { - return .{ - @as(u8, @truncate(x >> 8)), - @as(u8, @truncate(x)), - }; +pub inline fn int2(int: u16) [2]u8 { + var arr: [2]u8 = undefined; + std.mem.writeInt(u16, &arr, int, .big); + return arr; } -pub inline fn int3(x: u24) [3]u8 { - return .{ - @as(u8, @truncate(x >> 16)), - @as(u8, @truncate(x >> 8)), - @as(u8, @truncate(x)), - }; +pub inline fn int3(int: u24) [3]u8 { + var arr: [3]u8 = undefined; + std.mem.writeInt(u24, &arr, int, .big); + return arr; } /// An abstraction to ensure that protocol-parsing code does not perform an diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 84dbb2167a..c69c6ee936 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -8,12 +8,14 @@ const assert = std.debug.assert; const Certificate = std.crypto.Certificate; const max_ciphertext_len = tls.max_ciphertext_len; +const hmacExpandLabel = tls.hmacExpandLabel; const hkdfExpandLabel = tls.hkdfExpandLabel; const int2 = tls.int2; const int3 = tls.int3; const array = tls.array; const enum_array = tls.enum_array; +tls_version: tls.ProtocolVersion, read_seq: u64, write_seq: u64, /// The starting index of cleartext bytes inside `partially_read_buffer`. @@ -136,7 +138,7 @@ pub fn InitError(comptime Stream: type) type { }; } -/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which +/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `stream`, which /// must conform to `StreamInterface`. /// /// `host` is only borrowed during this function call. @@ -145,26 +147,20 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In var random_buffer: [128]u8 = undefined; crypto.random.bytes(&random_buffer); - const hello_rand = random_buffer[0..32].*; + const client_hello_rand = random_buffer[0..32].*; + var server_hello_rand: [32]u8 = undefined; const legacy_session_id = random_buffer[32..64].*; - const x25519_kp_seed = random_buffer[64..96].*; - const secp256r1_kp_seed = random_buffer[96..128].*; - const x25519_kp = crypto.dh.X25519.KeyPair.create(x25519_kp_seed) catch |err| switch (err) { - // Only possible to happen if the private key is all zeroes. + var key_share = KeyShare.init(random_buffer[64..128].*) catch |err| switch (err) { + // Only possible to happen if the seed is all zeroes. error.IdentityElement => return error.InsufficientEntropy, }; - const secp256r1_kp = crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair.create(secp256r1_kp_seed) catch |err| switch (err) { - // Only possible to happen if the private key is all zeroes. - error.IdentityElement => return error.InsufficientEntropy, - }; - const ml_kem768_kp = crypto.kem.ml_kem.MLKem768.KeyPair.create(null) catch {}; const extensions_payload = - tls.extension(.supported_versions, [_]u8{ - 0x02, // byte length of supported versions - 0x03, 0x04, // TLS 1.3 - }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ + tls.extension(.supported_versions, [_]u8{2 + 2} ++ // byte length of supported versions + int2(@intFromEnum(tls.ProtocolVersion.tls_1_3)) ++ + int2(@intFromEnum(tls.ProtocolVersion.tls_1_2))) ++ + tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384, .rsa_pss_rsae_sha256, @@ -178,11 +174,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In })) ++ tls.extension( .key_share, array(1, int2(@intFromEnum(tls.NamedGroup.x25519)) ++ - array(1, x25519_kp.public_key) ++ + array(1, key_share.x25519_kp.public_key) ++ int2(@intFromEnum(tls.NamedGroup.secp256r1)) ++ - array(1, secp256r1_kp.public_key.toUncompressedSec1()) ++ + array(1, key_share.secp256r1_kp.public_key.toUncompressedSec1()) ++ int2(@intFromEnum(tls.NamedGroup.x25519_ml_kem768)) ++ - array(1, x25519_kp.public_key ++ ml_kem768_kp.public_key.toBytes())), + array(1, key_share.x25519_kp.public_key ++ key_share.ml_kem768_kp.public_key.toBytes())), ) ++ int2(@intFromEnum(tls.ExtensionType.server_name)) ++ int2(host_len + 5) ++ // byte length of this extension payload @@ -198,7 +194,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const client_hello = int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - hello_rand ++ + client_hello_rand ++ [1]u8{32} ++ legacy_session_id ++ cipher_suites ++ int2(legacy_compression_methods) ++ @@ -209,16 +205,16 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In int3(@intCast(client_hello.len + host_len)) ++ client_hello; - const plaintext_header = [_]u8{ - @intFromEnum(tls.ContentType.handshake), - 0x03, 0x01, // legacy_record_version - } ++ int2(@intCast(out_handshake.len + host_len)) ++ out_handshake; + const cleartext_header = [_]u8{@intFromEnum(tls.ContentType.handshake)} ++ + int2(@intFromEnum(tls.ProtocolVersion.tls_1_0)) ++ // legacy_record_version + int2(@intCast(out_handshake.len + host_len)) ++ + out_handshake; { var iovecs = [_]std.posix.iovec_const{ .{ - .base = &plaintext_header, - .len = plaintext_header.len, + .base = &cleartext_header, + .len = cleartext_header.len, }, .{ .base = host.ptr, @@ -228,8 +224,10 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In try stream.writevAll(&iovecs); } - const client_hello_bytes1 = plaintext_header[5..]; + const client_hello_bytes1 = cleartext_header[tls.record_header_len..]; + var tls_version: tls.ProtocolVersion = undefined; + var cipher_suite_tag: tls.CipherSuite = undefined; var handshake_cipher: tls.HandshakeCipher = undefined; var handshake_buffer: [8000]u8 = undefined; var d: tls.Decoder = .{ .buf = &handshake_buffer }; @@ -259,10 +257,10 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In if (handshake_type != .server_hello) return error.TlsUnexpectedMessage; const length = ptd.decode(u24); var hsd = try ptd.sub(length); - try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2); + try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1); const legacy_version = hsd.decode(u16); - const random = hsd.array(32); - if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) { + @memcpy(&server_hello_rand, hsd.array(32)); + if (mem.eql(u8, &server_hello_rand, &tls.hello_retry_request_sequence)) { // This is a HelloRetryRequest message. This client implementation // does not expect to get one. return error.TlsUnexpectedMessage; @@ -270,83 +268,44 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const legacy_session_id_echo_len = hsd.decode(u8); if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter; const legacy_session_id_echo = hsd.array(32); - if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) - return error.TlsIllegalParameter; - const cipher_suite_tag = hsd.decode(tls.CipherSuite); + cipher_suite_tag = hsd.decode(tls.CipherSuite); hsd.skip(1); // legacy_compression_method - const extensions_size = hsd.decode(u16); - var all_extd = try hsd.sub(extensions_size); - var supported_version: u16 = 0; - var shared_key: []const u8 = undefined; - var have_shared_key = false; - while (!all_extd.eof()) { - try all_extd.ensure(2 + 2); - const et = all_extd.decode(tls.ExtensionType); - const ext_size = all_extd.decode(u16); - var extd = try all_extd.sub(ext_size); - switch (et) { - .supported_versions => { - if (supported_version != 0) return error.TlsIllegalParameter; - try extd.ensure(2); - supported_version = extd.decode(u16); - }, - .key_share => { - if (have_shared_key) return error.TlsIllegalParameter; - have_shared_key = true; - try extd.ensure(4); - const named_group = extd.decode(tls.NamedGroup); - const key_size = extd.decode(u16); - try extd.ensure(key_size); - switch (named_group) { - .x25519_ml_kem768 => { - const xksl = crypto.dh.X25519.public_length; - const hksl = xksl + crypto.kem.ml_kem.MLKem768.ciphertext_length; - if (key_size != hksl) - return error.TlsIllegalParameter; - const server_ks = extd.array(hksl); - - shared_key = &((crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - server_ks[0..xksl].*, - ) catch return error.TlsDecryptFailure) ++ (ml_kem768_kp.secret_key.decaps( - server_ks[xksl..hksl], - ) catch return error.TlsDecryptFailure)); - }, - .x25519 => { - const ksl = crypto.dh.X25519.public_length; - if (key_size != ksl) return error.TlsIllegalParameter; - const server_pub_key = extd.array(ksl); - - shared_key = &(crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - server_pub_key.*, - ) catch return error.TlsDecryptFailure); - }, - .secp256r1 => { - const server_pub_key = extd.slice(key_size); - - const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; - const pk = PublicKey.fromSec1(server_pub_key) catch { - return error.TlsDecryptFailure; - }; - const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .big) catch { - return error.TlsDecryptFailure; - }; - shared_key = &mul.affineCoordinates().x.toBytes(.big); - }, - else => { - return error.TlsIllegalParameter; - }, - } - }, - else => {}, + var supported_version: ?u16 = null; + if (!hsd.eof()) { + try hsd.ensure(2); + const extensions_size = hsd.decode(u16); + var all_extd = try hsd.sub(extensions_size); + while (!all_extd.eof()) { + try all_extd.ensure(2 + 2); + const et = all_extd.decode(tls.ExtensionType); + const ext_size = all_extd.decode(u16); + var extd = try all_extd.sub(ext_size); + switch (et) { + .supported_versions => { + if (supported_version) |_| return error.TlsIllegalParameter; + try extd.ensure(2); + supported_version = extd.decode(u16); + }, + .key_share => { + if (key_share.getSharedSecret()) |_| return error.TlsIllegalParameter; + try extd.ensure(4); + const named_group = extd.decode(tls.NamedGroup); + const key_size = extd.decode(u16); + try extd.ensure(key_size); + try key_share.exchange(named_group, extd.slice(key_size)); + }, + else => {}, + } } } - if (!have_shared_key) return error.TlsIllegalParameter; - const tls_version = if (supported_version == 0) legacy_version else supported_version; - if (tls_version != @intFromEnum(tls.ProtocolVersion.tls_1_3)) - return error.TlsIllegalParameter; + tls_version = @enumFromInt(supported_version orelse legacy_version); + switch (tls_version) { + .tls_1_3 => if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) return error.TlsIllegalParameter, + .tls_1_2 => if (mem.eql(u8, server_hello_rand[24..31], "DOWNGRD") and + server_hello_rand[31] >> 1 == 0x00) return error.TlsIllegalParameter, + else => return error.TlsIllegalParameter, + } switch (cipher_suite_tag) { inline .AES_128_GCM_SHA256, @@ -354,43 +313,63 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In .CHACHA20_POLY1305_SHA256, .AEGIS_256_SHA512, .AEGIS_128L_SHA256, + + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, => |tag| { - const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag)); - handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{ - .handshake_secret = undefined, - .master_secret = undefined, - .client_handshake_key = undefined, - .server_handshake_key = undefined, - .client_finished_key = undefined, - .server_finished_key = undefined, - .client_handshake_iv = undefined, - .server_handshake_iv = undefined, - .transcript_hash = P.Hash.init(.{}), + handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag.with()), .{ + .transcript_hash = .init(.{}), + .version = undefined, }); - const p = &@field(handshake_cipher, @tagName(tag)); + const p = &@field(handshake_cipher, @tagName(tag.with())); p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 p.transcript_hash.update(host); // Client Hello part 2 p.transcript_hash.update(server_hello_fragment); - const hello_hash = p.transcript_hash.peek(); - const zeroes = [1]u8{0} ** P.Hash.digest_length; - const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); - const empty_hash = tls.emptyHash(P.Hash); - const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); - p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, shared_key); - const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); - p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); - const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); - p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); - p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); - p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); }, - else => { - return error.TlsIllegalParameter; + + else => return error.TlsIllegalParameter, + } + switch (tls_version) { + .tls_1_3 => switch (cipher_suite_tag) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_256_SHA512, + .AEGIS_128L_SHA256, + => |tag| { + const sk = key_share.getSharedSecret() orelse return error.TlsIllegalParameter; + const p = &@field(handshake_cipher, @tagName(tag.with())); + const P = @TypeOf(p.*).A; + const hello_hash = p.transcript_hash.peek(); + const zeroes = [1]u8{0} ** P.Hash.digest_length; + const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); + const empty_hash = tls.emptyHash(P.Hash); + p.version = .{ .tls_1_3 = undefined }; + const pv = &p.version.tls_1_3; + const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); + pv.handshake_secret = P.Hkdf.extract(&hs_derived_secret, sk); + const ap_derived_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); + pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); + const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); + pv.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); + pv.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); + pv.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + pv.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + pv.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + pv.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + }, + else => return error.TlsIllegalParameter, }, + .tls_1_2 => switch (cipher_suite_tag) { + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => {}, + else => return error.TlsIllegalParameter, + }, + else => return error.TlsIllegalParameter, } }, else => return error.TlsUnexpectedMessage, @@ -404,58 +383,74 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In // the previous certificate in memory so that it can be verified by the // next one. var cert_index: usize = 0; + var write_seq: u64 = 0; var read_seq: u64 = 0; var prev_cert: Certificate.Parsed = undefined; - // Set to true once a trust chain has been established from the first - // certificate to a root CA. + const CipherState = enum { + /// No cipher is in use + cleartext, + /// Handshake cipher is in use + handshake, + /// Application cipher is in use + application, + }; + var pending_cipher_state: CipherState = switch (tls_version) { + .tls_1_3 => .handshake, + .tls_1_2 => .cleartext, + else => unreachable, + }; + var cipher_state: CipherState = .cleartext; const HandshakeState = enum { /// In this state we expect only an encrypted_extensions message. encrypted_extensions, - /// In this state we expect certificate messages. + /// In this state we expect certificate handshake messages. certificate, /// In this state we expect certificate or certificate_verify messages. /// certificate messages are ignored since the trust chain is already /// established. trust_chain_established, - /// In this state, we expect only the finished message. + /// In this state, we expect only the server_hello_done handshake message. + server_hello_done, + /// In this state, we expect only the finished handshake message. finished, }; - var handshake_state: HandshakeState = .encrypted_extensions; + var handshake_state: HandshakeState = switch (tls_version) { + .tls_1_3 => .encrypted_extensions, + .tls_1_2 => .certificate, + else => unreachable, + }; var cleartext_bufs: [2][8000]u8 = undefined; - var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; - var main_cert_pub_key_buf: [600]u8 = undefined; - var main_cert_pub_key_len: u16 = undefined; + var main_cert_pub_key: CertificatePublicKey = undefined; const now_sec = std.time.timestamp(); while (true) { try d.readAtLeastOurAmt(stream, tls.record_header_len); - const record_header = d.buf[d.idx..][0..5]; - const ct = d.decode(tls.ContentType); + const record_header = d.buf[d.idx..][0..tls.record_header_len]; + const record_ct = d.decode(tls.ContentType); d.skip(2); // legacy_version const record_len = d.decode(u16); try d.readAtLeast(stream, record_len); var record_decoder = try d.sub(record_len); - switch (ct) { - .change_cipher_spec => { - try record_decoder.ensure(1); - if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter; - }, - .application_data => { + var ctd, const ct = content: switch (cipher_state) { + .cleartext => .{ record_decoder, record_ct }, + .handshake => { + std.debug.assert(tls_version == .tls_1_3); + if (record_ct != .application_data) return error.TlsUnexpectedMessage; + try record_decoder.ensure(record_len); const cleartext_buf = &cleartext_bufs[cert_index % 2]; - - const cleartext = switch (handshake_cipher) { - inline else => |*p| c: { - const P = @TypeOf(p.*); - const ciphertext_len = record_len - P.AEAD.tag_length; - try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length); - const ciphertext = record_decoder.slice(ciphertext_len); + const cleartext = cleartext: switch (handshake_cipher) { + inline else => |*p| { + const pv = &p.version.tls_1_3; + const P = @TypeOf(p.*).A; + if (record_len < P.AEAD.tag_length) return error.TlsRecordOverflow; + const ciphertext = record_decoder.slice(record_len - P.AEAD.tag_length); if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; const cleartext = cleartext_buf[0..ciphertext.len]; const auth_tag = record_decoder.array(P.AEAD.tag_length).*; const nonce = if (builtin.zig_backend == .stage2_x86_64 and P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) nonce: { - var nonce = p.server_handshake_iv; + var nonce = pv.server_handshake_iv; const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ read_seq, .big); break :nonce nonce; @@ -463,200 +458,320 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const V = @Vector(P.AEAD.nonce_length, u8); const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @as([8]u8, @bitCast(big(read_seq))); - break :nonce @as(V, p.server_handshake_iv) ^ operand; + break :nonce @as(V, pv.server_handshake_iv) ^ operand; }; - read_seq += 1; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch return error.TlsBadRecordMac; - break :c @constCast(mem.trimRight(u8, cleartext, "\x00")); + break :cleartext mem.trimRight(u8, cleartext, "\x00"); }, }; + read_seq += 1; + const ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); + if (ct != .handshake) return error.TlsUnexpectedMessage; + break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext[0 .. cleartext.len - 1])), ct }; + }, + .application => { + std.debug.assert(tls_version == .tls_1_2); + if (record_ct != .handshake) return error.TlsUnexpectedMessage; + try record_decoder.ensure(record_len); + const cleartext_buf = &cleartext_bufs[cert_index % 2]; + const cleartext = cleartext: switch (handshake_cipher) { + inline else => |*p| { + const pv = &p.version.tls_1_2; + const P = @TypeOf(p.*).A; + if (record_len < P.record_iv_length + P.mac_length) return error.TlsRecordOverflow; + const message_len: u16 = record_len - P.record_iv_length - P.mac_length; + if (message_len > cleartext_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_buf[0..message_len]; + const ad = std.mem.toBytes(big(read_seq)) ++ + record_header[0 .. 1 + 2] ++ + std.mem.toBytes(big(message_len)); + const record_iv = record_decoder.array(P.record_iv_length).*; + const masked_read_seq = read_seq & + comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); + const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.app_cipher.server_write_IV ++ record_iv; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ masked_read_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq))); + break :nonce @as(V, pv.app_cipher.server_write_IV ++ record_iv) ^ operand; + }; + const ciphertext = record_decoder.slice(message_len); + const auth_tag = record_decoder.array(P.mac_length); + P.AEAD.decrypt(cleartext, ciphertext, auth_tag.*, ad, nonce, pv.app_cipher.server_write_key) catch return error.TlsBadRecordMac; + break :cleartext cleartext; + }, + }; + read_seq += 1; + break :content .{ tls.Decoder.fromTheirSlice(cleartext), record_ct }; + }, + }; + switch (ct) { + .alert => { + try ctd.ensure(2); + const level = ctd.decode(tls.AlertLevel); + const desc = ctd.decode(tls.AlertDescription); + _ = level; - const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - if (inner_ct != .handshake) return error.TlsUnexpectedMessage; - - var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]); - while (true) { - try ctd.ensure(4); - const handshake_type = ctd.decode(tls.HandshakeType); - const handshake_len = ctd.decode(u24); - var hsd = try ctd.sub(handshake_len); - const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; - const handshake = ctd.buf[ctd.idx - handshake_len .. ctd.idx]; - switch (handshake_type) { - .encrypted_extensions => { - if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; - handshake_state = .certificate; - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), + // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake + try desc.toError(); + // TODO: handle server-side closures + return error.TlsUnexpectedMessage; + }, + .change_cipher_spec => { + try ctd.ensure(1); + if (ctd.decode(u8) != 0x01) return error.TlsIllegalParameter; + cipher_state = pending_cipher_state; + }, + .handshake => while (true) { + try ctd.ensure(4); + const handshake_type = ctd.decode(tls.HandshakeType); + const handshake_len = ctd.decode(u24); + var hsd = try ctd.sub(handshake_len); + const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; + switch (handshake_type) { + .encrypted_extensions => { + if (tls_version != .tls_1_3) return error.TlsUnexpectedMessage; + if (cipher_state != .handshake) return error.TlsUnexpectedMessage; + if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; + handshake_state = .certificate; + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + try hsd.ensure(2); + const total_ext_size = hsd.decode(u16); + var all_extd = try hsd.sub(total_ext_size); + while (!all_extd.eof()) { + try all_extd.ensure(4); + const et = all_extd.decode(tls.ExtensionType); + const ext_size = all_extd.decode(u16); + const extd = try all_extd.sub(ext_size); + _ = extd; + switch (et) { + .server_name => {}, + else => {}, } - try hsd.ensure(2); - const total_ext_size = hsd.decode(u16); - var all_extd = try hsd.sub(total_ext_size); - while (!all_extd.eof()) { - try all_extd.ensure(4); - const et = all_extd.decode(tls.ExtensionType); - const ext_size = all_extd.decode(u16); - const extd = try all_extd.sub(ext_size); - _ = extd; - switch (et) { - .server_name => {}, - else => {}, - } + } + }, + .certificate => cert: { + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + switch (handshake_state) { + .certificate => {}, + .trust_chain_established => break :cert, + else => return error.TlsUnexpectedMessage, + } + + switch (tls_version) { + .tls_1_3 => { + try hsd.ensure(1 + 3); + const cert_req_ctx_len = hsd.decode(u8); + if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; + }, + .tls_1_2 => try hsd.ensure(3), + else => unreachable, + } + const certs_size = hsd.decode(u24); + var certs_decoder = try hsd.sub(certs_size); + while (!certs_decoder.eof()) { + try certs_decoder.ensure(3); + const cert_size = certs_decoder.decode(u24); + const certd = try certs_decoder.sub(cert_size); + + const subject_cert: Certificate = .{ + .buffer = certd.buf, + .index = @intCast(certd.idx), + }; + const subject = try subject_cert.parse(); + if (cert_index == 0) { + // Verify the host on the first certificate. + try subject.verifyHostName(host); + + // Keep track of the public key for the + // certificate_verify message later. + try main_cert_pub_key.init(subject.pub_key_algo, subject.pubKey()); + } else { + try prev_cert.verify(subject, now_sec); } - }, - .certificate => cert: { - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), + + if (ca_bundle.verify(subject, now_sec)) |_| { + handshake_state = .trust_chain_established; + break :cert; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + else => |e| return e, } - switch (handshake_state) { - .certificate => {}, - .trust_chain_established => break :cert, - else => return error.TlsUnexpectedMessage, - } - try hsd.ensure(1 + 4); - const cert_req_ctx_len = hsd.decode(u8); - if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; - const certs_size = hsd.decode(u24); - var certs_decoder = try hsd.sub(certs_size); - while (!certs_decoder.eof()) { - try certs_decoder.ensure(3); - const cert_size = certs_decoder.decode(u24); - const certd = try certs_decoder.sub(cert_size); - const subject_cert: Certificate = .{ - .buffer = certd.buf, - .index = @intCast(certd.idx), - }; - const subject = try subject_cert.parse(); - if (cert_index == 0) { - // Verify the host on the first certificate. - try subject.verifyHostName(host); - - // Keep track of the public key for the - // certificate_verify message later. - main_cert_pub_key_algo = subject.pub_key_algo; - const pub_key = subject.pubKey(); - if (pub_key.len > main_cert_pub_key_buf.len) - return error.CertificatePublicKeyInvalid; - @memcpy(main_cert_pub_key_buf[0..pub_key.len], pub_key); - main_cert_pub_key_len = @intCast(pub_key.len); - } else { - try prev_cert.verify(subject, now_sec); - } - - if (ca_bundle.verify(subject, now_sec)) |_| { - handshake_state = .trust_chain_established; - break :cert; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - else => |e| return e, - } - - prev_cert = subject; - cert_index += 1; + prev_cert = subject; + cert_index += 1; + if (tls_version == .tls_1_3) { try certs_decoder.ensure(2); const total_ext_size = certs_decoder.decode(u16); const all_extd = try certs_decoder.sub(total_ext_size); _ = all_extd; } - }, - .certificate_verify => { - switch (handshake_state) { - .trust_chain_established => handshake_state = .finished, - .certificate => return error.TlsCertificateNotVerified, - else => return error.TlsUnexpectedMessage, - } + } + }, + .server_key_exchange => { + if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage; + if (cipher_state != .cleartext) return error.TlsUnexpectedMessage; + switch (handshake_state) { + .trust_chain_established => handshake_state = .server_hello_done, + .certificate => return error.TlsCertificateNotVerified, + else => return error.TlsUnexpectedMessage, + } - try hsd.ensure(4); - const scheme = hsd.decode(tls.SignatureScheme); - const sig_len = hsd.decode(u16); - try hsd.ensure(sig_len); - const encoded_sig = hsd.slice(sig_len); - const max_digest_len = 64; - var verify_buffer: [64 + 34 + max_digest_len]u8 = - ([1]u8{0x20} ** 64) ++ - "TLS 1.3, server CertificateVerify\x00".* ++ - @as([max_digest_len]u8, undefined); + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + try hsd.ensure(1 + 2 + 1); + const curve_type = hsd.decode(u8); + if (curve_type != 0x03) return error.TlsIllegalParameter; // named_curve + const named_group = hsd.decode(tls.NamedGroup); + if (named_group != .secp256r1) return error.TlsIllegalParameter; + const key_size = hsd.decode(u8); + try hsd.ensure(key_size); + const server_pub_key = hsd.slice(key_size); + try main_cert_pub_key.verifySignature(&hsd, &.{ &client_hello_rand, &server_hello_rand, hsd.buf[0..hsd.idx] }); + try key_share.exchange(named_group, server_pub_key); + }, + .server_hello_done => { + if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage; + if (cipher_state != .cleartext) return error.TlsUnexpectedMessage; + if (handshake_state != .server_hello_done) return error.TlsUnexpectedMessage; + handshake_state = .finished; - const verify_bytes = switch (handshake_cipher) { - inline else => |*p| v: { - const transcript_digest = p.transcript_hash.peek(); - verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; - p.transcript_hash.update(wrapped_handshake); - break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len]; - }, - }; - const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; - - switch (scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) - return error.TlsBadSignatureScheme; - const Ecdsa = SchemeEcdsa(comptime_scheme); - const sig = try Ecdsa.Signature.fromDer(encoded_sig); - const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); - try sig.verify(verify_bytes, key); - }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .rsaEncryption) - return error.TlsBadSignatureScheme; - - const Hash = SchemeHash(comptime_scheme); - const rsa = Certificate.rsa; - const components = try rsa.PublicKey.parseDer(main_cert_pub_key); - const exponent = components.exponent; - const modulus = components.modulus; - switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { - const key = try rsa.PublicKey.fromBytes(exponent, modulus); - const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash); - }, - else => { - return error.TlsBadRsaSignatureBitCount; - }, - } - }, - inline .ed25519 => |comptime_scheme| { - if (main_cert_pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; - const Eddsa = SchemeEddsa(comptime_scheme); - if (encoded_sig.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; - const sig = Eddsa.Signature.fromBytes(encoded_sig[0..Eddsa.Signature.encoded_length].*); - if (main_cert_pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; - const key = try Eddsa.PublicKey.fromBytes(main_cert_pub_key[0..Eddsa.PublicKey.encoded_length].*); - try sig.verify(verify_bytes, key); - }, - else => { - return error.TlsBadSignatureScheme; - }, - } - }, - .finished => { - if (handshake_state != .finished) return error.TlsUnexpectedMessage; - // This message is to trick buggy proxies into behaving correctly. - const client_change_cipher_spec_msg = [_]u8{ - @intFromEnum(tls.ContentType.change_cipher_spec), - 0x03, 0x03, // legacy protocol version - 0x00, 0x01, // length - 0x01, - }; - const app_cipher = switch (handshake_cipher) { - inline else => |*p, tag| c: { - const P = @TypeOf(p.*); + const client_key_exchange_msg = + [_]u8{@intFromEnum(tls.ContentType.handshake)} ++ // record content type + int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version + int2(0x46) ++ // record length + .{@intFromEnum(tls.HandshakeType.client_key_exchange)} ++ // handshake type + int3(0x42) ++ // params length + .{0x41} ++ // pubkey length + key_share.secp256r1_kp.public_key.toUncompressedSec1(); + // This message is to trick buggy proxies into behaving correctly. + const client_change_cipher_spec_msg = + [_]u8{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ // record content type + int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version + int2(1) ++ // record length + .{0x01}; + const pre_master_secret = key_share.getSharedSecret().?; + switch (handshake_cipher) { + inline else => |*p| { + const P = @TypeOf(p.*).A; + p.transcript_hash.update(wrapped_handshake); + p.transcript_hash.update(client_key_exchange_msg[tls.record_header_len..]); + const master_secret = hmacExpandLabel(P.Hmac, pre_master_secret, &.{ + "master secret", + &client_hello_rand, + &server_hello_rand, + }, 48); + const key_block = hmacExpandLabel( + P.Hmac, + &master_secret, + &.{ "key expansion", &server_hello_rand, &client_hello_rand }, + @sizeOf(P.Tls_1_2), + ); + const verify_data_len = 12; + const client_verify_cleartext = + [_]u8{@intFromEnum(tls.HandshakeType.finished)} ++ // handshake type + int3(verify_data_len) ++ // verify data length + hmacExpandLabel(P.Hmac, &master_secret, &.{ "client finished", &p.transcript_hash.peek() }, verify_data_len); + p.transcript_hash.update(&client_verify_cleartext); + p.version = .{ .tls_1_2 = .{ + .server_verify_data = hmacExpandLabel( + P.Hmac, + &master_secret, + &.{ "server finished", &p.transcript_hash.finalResult() }, + verify_data_len, + ), + .app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block), + } }; + const pv = &p.version.tls_1_2; + pending_cipher_state = .application; + const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.app_cipher.client_write_IV ++ pv.app_cipher.client_salt; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ write_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @as([8]u8, @bitCast(big(write_seq))); + break :nonce @as(V, pv.app_cipher.client_write_IV ++ pv.app_cipher.client_salt) ^ operand; + }; + var client_verify_msg = [_]u8{@intFromEnum(tls.ContentType.handshake)} ++ // record content type + int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version + int2(P.record_iv_length + client_verify_cleartext.len + P.mac_length) ++ // record length + nonce[P.fixed_iv_length..].* ++ + @as([client_verify_cleartext.len + P.mac_length]u8, undefined); + P.AEAD.encrypt( + client_verify_msg[client_verify_msg.len - P.mac_length - + client_verify_cleartext.len ..][0..client_verify_cleartext.len], + client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length], + &client_verify_cleartext, + std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int2(client_verify_cleartext.len), + nonce, + pv.app_cipher.client_write_key, + ); + const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg; + var all_msgs_vec = [_]std.posix.iovec_const{.{ + .base = &all_msgs, + .len = all_msgs.len, + }}; + try stream.writevAll(&all_msgs_vec); + }, + } + write_seq += 1; + }, + .certificate_verify => { + if (tls_version != .tls_1_3) return error.TlsUnexpectedMessage; + if (cipher_state != .handshake) return error.TlsUnexpectedMessage; + switch (handshake_state) { + .trust_chain_established => handshake_state = .finished, + .certificate => return error.TlsCertificateNotVerified, + else => return error.TlsUnexpectedMessage, + } + switch (handshake_cipher) { + inline else => |*p| { + try main_cert_pub_key.verifySignature(&hsd, &.{ + " " ** 64 ++ "TLS 1.3, server CertificateVerify\x00", + &p.transcript_hash.peek(), + }); + p.transcript_hash.update(wrapped_handshake); + }, + } + }, + .finished => { + if (cipher_state == .cleartext) return error.TlsUnexpectedMessage; + if (handshake_state != .finished) return error.TlsUnexpectedMessage; + // This message is to trick buggy proxies into behaving correctly. + const client_change_cipher_spec_msg = + [_]u8{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ + int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version + int2(1) ++ // length + .{0x01}; + const app_cipher = app_cipher: switch (handshake_cipher) { + inline else => |*p, tag| switch (tls_version) { + .tls_1_3 => { + const pv = &p.version.tls_1_3; + const P = @TypeOf(p.*).A; const finished_digest = p.transcript_hash.peek(); p.transcript_hash.update(wrapped_handshake); - const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); - if (!mem.eql(u8, &expected_server_verify_data, handshake)) - return error.TlsDecryptError; + const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, pv.server_finished_key); + if (!mem.eql(u8, &expected_server_verify_data, hsd.buf)) return error.TlsDecryptError; const handshake_hash = p.transcript_hash.finalResult(); - const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); + const verify_data = tls.hmac(P.Hmac, &handshake_hash, pv.client_finished_key); const out_cleartext = [_]u8{ @intFromEnum(tls.HandshakeType.finished), 0, 0, verify_data.len, // length @@ -664,67 +779,78 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - var finished_msg = [_]u8{ - @intFromEnum(tls.ContentType.application_data), - 0x03, 0x03, // legacy protocol version - 0, wrapped_len, // byte length of encrypted record - } ++ @as([wrapped_len]u8, undefined); + var finished_msg = [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ + int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version + int2(wrapped_len) ++ // byte length of encrypted record + @as([wrapped_len]u8, undefined); - const ad = finished_msg[0..5]; - const ciphertext = finished_msg[5..][0..out_cleartext.len]; + const ad = finished_msg[0..tls.record_header_len]; + const ciphertext = finished_msg[tls.record_header_len..][0..out_cleartext.len]; const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; - const nonce = p.client_handshake_iv; - P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); + const nonce = pv.client_handshake_iv; + P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key); - const both_msgs = client_change_cipher_spec_msg ++ finished_msg; - var both_msgs_vec = [_]std.posix.iovec_const{.{ - .base = &both_msgs, - .len = both_msgs.len, + const all_msgs = client_change_cipher_spec_msg ++ finished_msg; + var all_msgs_vec = [_]std.posix.iovec_const{.{ + .base = &all_msgs, + .len = all_msgs.len, }}; - try stream.writevAll(&both_msgs_vec); + try stream.writevAll(&all_msgs_vec); - const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ + const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + break :app_cipher @unionInit(tls.ApplicationCipher, @tagName(tag), .{ .tls_1_3 = .{ .client_secret = client_secret, .server_secret = server_secret, .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), - }); + } }); }, - }; - const leftover = d.rest(); - var client: Client = .{ - .read_seq = 0, - .write_seq = 0, - .partial_cleartext_idx = 0, - .partial_ciphertext_idx = 0, - .partial_ciphertext_end = @intCast(leftover.len), - .received_close_notify = false, - .application_cipher = app_cipher, - .partially_read_buffer = undefined, - }; - @memcpy(client.partially_read_buffer[0..leftover.len], leftover); - return client; - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - if (ctd.eof()) break; + .tls_1_2 => { + const pv = &p.version.tls_1_2; + try hsd.ensure(12); + if (!std.mem.eql(u8, hsd.array(12), &pv.server_verify_data)) return error.TlsDecryptError; + break :app_cipher @unionInit(tls.ApplicationCipher, @tagName(tag), .{ .tls_1_2 = pv.app_cipher }); + }, + else => unreachable, + }, + }; + const leftover = d.rest(); + var client: Client = .{ + .tls_version = tls_version, + .read_seq = switch (tls_version) { + .tls_1_3 => 0, + .tls_1_2 => read_seq, + else => unreachable, + }, + .write_seq = switch (tls_version) { + .tls_1_3 => 0, + .tls_1_2 => write_seq, + else => unreachable, + }, + .partial_cleartext_idx = 0, + .partial_ciphertext_idx = 0, + .partial_ciphertext_end = @intCast(leftover.len), + .received_close_notify = false, + .application_cipher = app_cipher, + .partially_read_buffer = undefined, + }; + @memcpy(client.partially_read_buffer[0..leftover.len], leftover); + return client; + }, + else => return error.TlsUnexpectedMessage, } + if (ctd.eof()) break; }, - else => { - return error.TlsUnexpectedMessage; - }, + else => return error.TlsUnexpectedMessage, } } } /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. +/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { return writeEnd(c, stream, bytes, false); } @@ -749,7 +875,7 @@ pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !v } /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. +/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. /// If `end` is true, then this function additionally sends a `close_notify` alert, /// which is necessary for the server to distinguish between a properly finished /// TLS session, or a truncation attack. @@ -813,62 +939,127 @@ fn prepareCiphertextRecord( var iovec_end: usize = 0; var bytes_i: usize = 0; switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; - while (true) { - const encrypted_content_len: u16 = @intCast(@min( - @min(bytes.len - bytes_i, tls.max_ciphertext_inner_record_len), - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), - )); - if (encrypted_content_len == 0) return .{ - .iovec_end = iovec_end, - .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, - }; + inline else => |*p| switch (c.tls_version) { + .tls_1_3 => { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; + const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; + while (true) { + const encrypted_content_len: u16 = @min( + bytes.len - bytes_i, + tls.max_ciphertext_inner_record_len, + ciphertext_buf.len -| + (close_notify_alert_reserved + overhead_len + ciphertext_end), + ); + if (encrypted_content_len == 0) return .{ + .iovec_end = iovec_end, + .ciphertext_end = ciphertext_end, + .overhead_len = overhead_len, + }; - @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); - cleartext_buf[encrypted_content_len] = @intFromEnum(inner_content_type); - bytes_i += encrypted_content_len; - const ciphertext_len = encrypted_content_len + 1; - const cleartext = cleartext_buf[0..ciphertext_len]; + @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); + cleartext_buf[encrypted_content_len] = @intFromEnum(inner_content_type); + bytes_i += encrypted_content_len; + const ciphertext_len = encrypted_content_len + 1; + const cleartext = cleartext_buf[0..ciphertext_len]; - const record_start = ciphertext_end; - const ad = ciphertext_buf[ciphertext_end..][0..5]; - ad.* = - [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - int2(ciphertext_len + P.AEAD.tag_length); - ciphertext_end += ad.len; - const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; - ciphertext_end += ciphertext_len; - const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; - ciphertext_end += auth_tag.len; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.client_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.write_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.write_seq))); - break :nonce @as(V, p.client_iv) ^ operand; - }; - c.write_seq += 1; // TODO send key_update on overflow - P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); + const record_start = ciphertext_end; + const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; + ad.* = + [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ + int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + int2(ciphertext_len + P.AEAD.tag_length); + ciphertext_end += ad.len; + const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; + ciphertext_end += ciphertext_len; + const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; + ciphertext_end += auth_tag.len; + const nonce = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.client_iv; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.write_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ std.mem.toBytes(big(c.write_seq)); + break :nonce @as(V, pv.client_iv) ^ operand; + }; + P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key); + c.write_seq += 1; // TODO send key_update on overflow - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .base = record.ptr, - .len = record.len, - }; - iovec_end += 1; - } + const record = ciphertext_buf[record_start..ciphertext_end]; + iovecs[iovec_end] = .{ + .base = record.ptr, + .len = record.len, + }; + iovec_end += 1; + } + }, + .tls_1_2 => { + const pv = &p.tls_1_2; + const P = @TypeOf(p.*); + const overhead_len = tls.record_header_len + P.record_iv_length + P.mac_length; + const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; + while (true) { + const message_len: u16 = @min( + bytes.len - bytes_i, + tls.max_ciphertext_inner_record_len, + ciphertext_buf.len -| + (close_notify_alert_reserved + overhead_len + ciphertext_end), + ); + if (message_len == 0) return .{ + .iovec_end = iovec_end, + .ciphertext_end = ciphertext_end, + .overhead_len = overhead_len, + }; + + @memcpy(cleartext_buf[0..message_len], bytes[bytes_i..][0..message_len]); + bytes_i += message_len; + const cleartext = cleartext_buf[0..message_len]; + + const record_start = ciphertext_end; + const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; + ciphertext_end += tls.record_header_len; + record_header.* = [_]u8{@intFromEnum(inner_content_type)} ++ + int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + int2(P.record_iv_length + message_len + P.mac_length); + const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int2(message_len); + const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length]; + ciphertext_end += P.record_iv_length; + const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.client_write_IV ++ pv.client_salt; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.write_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @as([8]u8, @bitCast(big(c.write_seq))); + break :nonce @as(V, pv.client_write_IV ++ pv.client_salt) ^ operand; + }; + record_iv.* = nonce[P.fixed_iv_length..].*; + const ciphertext = ciphertext_buf[ciphertext_end..][0..message_len]; + ciphertext_end += message_len; + const auth_tag = ciphertext_buf[ciphertext_end..][0..P.mac_length]; + ciphertext_end += P.mac_length; + P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_write_key); + c.write_seq += 1; // TODO send key_update on overflow + + const record = ciphertext_buf[record_start..ciphertext_end]; + iovecs[iovec_end] = .{ + .base = record.ptr, + .len = record.len, + }; + iovec_end += 1; + } + }, + else => unreachable, }, } } @@ -990,7 +1181,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove // beginning of the buffer will be used for such purposes. const cleartext_buf_len = free_size - ciphertext_buf_len; - // Recoup `partially_read_buffer space`. This is necessary because it is assumed + // Recoup `partially_read_buffer` space. This is necessary because it is assumed // below that `frag0` is big enough to hold at least one record. limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx); c.partial_ciphertext_end -= c.partial_ciphertext_idx; @@ -1105,159 +1296,182 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove in = 0; continue; } - switch (ct) { + const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) { + inline else => |*p| switch (c.tls_version) { + .tls_1_3 => { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const ad = frag[in - tls.record_header_len ..][0..tls.record_header_len]; + const ciphertext_len = record_len - P.AEAD.tag_length; + const ciphertext = frag[in..][0..ciphertext_len]; + in += ciphertext_len; + const auth_tag = frag[in..][0..P.AEAD.tag_length].*; + const nonce = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.server_iv; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.read_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ std.mem.toBytes(big(c.read_seq)); + break :nonce @as(V, pv.server_iv) ^ operand; + }; + const out_buf = vp.peek(); + const cleartext_buf = if (ciphertext.len <= out_buf.len) + out_buf + else + &cleartext_stack_buffer; + const cleartext = cleartext_buf[0..ciphertext.len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch + return error.TlsBadRecordMac; + const msg = mem.trimRight(u8, cleartext, "\x00"); + break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) }; + }, + .tls_1_2 => { + const pv = &p.tls_1_2; + const P = @TypeOf(p.*); + const message_len: u16 = record_len - P.record_iv_length - P.mac_length; + const ad = std.mem.toBytes(big(c.read_seq)) ++ + frag[in - tls.record_header_len ..][0 .. 1 + 2] ++ + std.mem.toBytes(big(message_len)); + const record_iv = frag[in..][0..P.record_iv_length].*; + in += P.record_iv_length; + const masked_read_seq = c.read_seq & + comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); + const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.server_write_IV ++ record_iv; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ masked_read_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq))); + break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand; + }; + const ciphertext = frag[in..][0..message_len]; + in += message_len; + const auth_tag = frag[in..][0..P.mac_length].*; + in += P.mac_length; + const out_buf = vp.peek(); + const cleartext_buf = if (message_len <= out_buf.len) + out_buf + else + &cleartext_stack_buffer; + const cleartext = cleartext_buf[0..ciphertext.len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch + return error.TlsBadRecordMac; + break :cleartext .{ cleartext, ct }; + }, + else => unreachable, + }, + }; + c.read_seq = try std.math.add(u64, c.read_seq, 1); + switch (inner_ct) { .alert => { - if (in + 2 > frag.len) return error.TlsDecodeError; - const level: tls.AlertLevel = @enumFromInt(frag[in]); - const desc: tls.AlertDescription = @enumFromInt(frag[in + 1]); + if (cleartext.len != 2) return error.TlsDecodeError; + const level: tls.AlertLevel = @enumFromInt(cleartext[0]); + const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); + if (desc == .close_notify) { + c.received_close_notify = true; + c.partial_ciphertext_end = c.partial_ciphertext_idx; + return vp.total; + } _ = level; try desc.toError(); // TODO: handle server-side closures return error.TlsUnexpectedMessage; }, - .application_data => { - const cleartext = switch (c.application_cipher) { - inline else => |*p| c: { - const P = @TypeOf(p.*); - const ad = frag[in - 5 ..][0..5]; - const ciphertext_len = record_len - P.AEAD.tag_length; - const ciphertext = frag[in..][0..ciphertext_len]; - in += ciphertext_len; - const auth_tag = frag[in..][0..P.AEAD.tag_length].*; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.server_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.read_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.read_seq))); - break :nonce @as(V, p.server_iv) ^ operand; - }; - const out_buf = vp.peek(); - const cleartext_buf = if (ciphertext.len <= out_buf.len) - out_buf - else - &cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch - return error.TlsBadRecordMac; - break :c mem.trimRight(u8, cleartext, "\x00"); - }, - }; - - c.read_seq = try std.math.add(u64, c.read_seq, 1); - - const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - switch (inner_ct) { - .alert => { - const level: tls.AlertLevel = @enumFromInt(cleartext[0]); - const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); - if (desc == .close_notify) { - c.received_close_notify = true; - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - var ct_i: usize = 0; - while (true) { - const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); - ct_i += 1; - const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); - ct_i += 3; - const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len - 1) - return error.TlsBadLength; - const handshake = cleartext[ct_i..next_handshake_i]; - switch (handshake_type) { - .new_session_ticket => { - // This client implementation ignores new session tickets. + .handshake => { + var ct_i: usize = 0; + while (true) { + const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); + ct_i += 1; + const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); + ct_i += 3; + const next_handshake_i = ct_i + handshake_len; + if (next_handshake_i > cleartext.len) + return error.TlsBadLength; + const handshake = cleartext[ct_i..next_handshake_i]; + switch (handshake_type) { + .new_session_ticket => { + // This client implementation ignores new session tickets. + }, + .key_update => { + switch (c.application_cipher) { + inline else => |*p| { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length); + pv.server_secret = server_secret; + pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); }, - .key_update => { + } + c.read_seq = 0; + + switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { + .update_requested => { switch (c.application_cipher) { inline else => |*p| { + const pv = &p.tls_1_3; const P = @TypeOf(p.*); - const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); - p.server_secret = server_secret; - p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length); + pv.client_secret = client_secret; + pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); }, } - c.read_seq = 0; - - switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { - .update_requested => { - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); - p.client_secret = client_secret; - p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.write_seq = 0; - }, - .update_not_requested => {}, - _ => return error.TlsIllegalParameter, - } - }, - else => { - return error.TlsUnexpectedMessage; + c.write_seq = 0; }, + .update_not_requested => {}, + _ => return error.TlsIllegalParameter, } - ct_i = next_handshake_i; - if (ct_i >= cleartext.len - 1) break; - } - }, - .application_data => { - // Determine whether the output buffer or a stack - // buffer was used for storing the cleartext. - if (cleartext.ptr == &cleartext_stack_buffer) { - // Stack buffer was used, so we must copy to the output buffer. - const msg = cleartext[0 .. cleartext.len - 1]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // We have already run out of room in iovecs. Continue - // appending to `partially_read_buffer`. - @memcpy( - c.partially_read_buffer[c.partial_ciphertext_idx..][0..msg.len], - msg, - ); - c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + msg.len); - } else { - const amt = vp.put(msg); - if (amt < msg.len) { - const rest = msg[amt..]; - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = @intCast(rest.len); - @memcpy(c.partially_read_buffer[0..rest.len], rest); - } - } - } else { - // Output buffer was used directly which means no - // memory copying needs to occur, and we can move - // on to the next ciphertext record. - vp.next(cleartext.len - 1); - } - }, - else => { - return error.TlsUnexpectedMessage; - }, + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + ct_i = next_handshake_i; + if (ct_i >= cleartext.len) break; } }, - else => { - return error.TlsUnexpectedMessage; + .application_data => { + // Determine whether the output buffer or a stack + // buffer was used for storing the cleartext. + if (cleartext.ptr == &cleartext_stack_buffer) { + // Stack buffer was used, so we must copy to the output buffer. + if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { + // We have already run out of room in iovecs. Continue + // appending to `partially_read_buffer`. + @memcpy( + c.partially_read_buffer[c.partial_ciphertext_idx..][0..cleartext.len], + cleartext, + ); + c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + cleartext.len); + } else { + const amt = vp.put(cleartext); + if (amt < cleartext.len) { + const rest = cleartext[amt..]; + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = @intCast(rest.len); + @memcpy(c.partially_read_buffer[0..rest.len], rest); + } + } + } else { + // Output buffer was used directly which means no + // memory copying needs to occur, and we can move + // on to the next ciphertext record. + vp.next(cleartext.len); + } }, + else => return error.TlsUnexpectedMessage, } in = end; } @@ -1326,6 +1540,74 @@ inline fn big(x: anytype) @TypeOf(x) { }; } +const KeyShare = struct { + x25519_kp: crypto.dh.X25519.KeyPair, + secp256r1_kp: crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair, + ml_kem768_kp: crypto.kem.ml_kem.MLKem768.KeyPair, + sk_buf: [sk_max_len]u8, + sk_len: std.math.IntFittingRange(0, sk_max_len), + + const sk_max_len = @max( + crypto.dh.X25519.shared_length + crypto.kem.ml_kem.MLKem768.shared_length, + crypto.dh.X25519.shared_length, + crypto.ecc.P256.scalar.encoded_length, + ); + + fn init(seed: [64]u8) error{IdentityElement}!KeyShare { + return .{ + .x25519_kp = try .create(seed[0..32].*), + .secp256r1_kp = try .create(seed[32..64].*), + .ml_kem768_kp = try .create(null), + .sk_buf = undefined, + .sk_len = 0, + }; + } + + fn exchange( + ks: *KeyShare, + named_group: tls.NamedGroup, + server_pub_key: []const u8, + ) error{ TlsIllegalParameter, TlsDecryptFailure }!void { + switch (named_group) { + .x25519_ml_kem768 => { + const xksl = crypto.dh.X25519.public_length; + const hksl = xksl + crypto.kem.ml_kem.MLKem768.ciphertext_length; + if (server_pub_key.len != hksl) return error.TlsIllegalParameter; + + const xsk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..xksl].*) catch + return error.TlsDecryptFailure; + const hsk = ks.ml_kem768_kp.secret_key.decaps(server_pub_key[xksl..hksl]) catch + return error.TlsDecryptFailure; + @memcpy(ks.sk_buf[0..xsk.len], &xsk); + @memcpy(ks.sk_buf[xsk.len..][0..hsk.len], &hsk); + ks.sk_len = xsk.len + hsk.len; + }, + .x25519 => { + const ksl = crypto.dh.X25519.public_length; + if (server_pub_key.len != ksl) return error.TlsIllegalParameter; + const sk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..ksl].*) catch + return error.TlsDecryptFailure; + @memcpy(ks.sk_buf[0..sk.len], &sk); + ks.sk_len = sk.len; + }, + .secp256r1 => { + const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; + const pk = PublicKey.fromSec1(server_pub_key) catch return error.TlsDecryptFailure; + const mul = pk.p.mulPublic(ks.secp256r1_kp.secret_key.bytes, .big) catch + return error.TlsDecryptFailure; + const sk = mul.affineCoordinates().x.toBytes(.big); + @memcpy(ks.sk_buf[0..sk.len], &sk); + ks.sk_len = sk.len; + }, + else => return error.TlsIllegalParameter, + } + } + + fn getSharedSecret(ks: *const KeyShare) ?[]const u8 { + return if (ks.sk_len > 0) ks.sk_buf[0..ks.sk_len] else null; + } +}; + fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { return switch (scheme) { .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256, @@ -1334,11 +1616,20 @@ fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { }; } -fn SchemeHash(comptime scheme: tls.SignatureScheme) type { +fn SchemeRsa(comptime scheme: tls.SignatureScheme) type { return switch (scheme) { - .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, - .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, - .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + .rsa_pkcs1_sha1, + => Certificate.rsa.PKCS1v1_5Signature, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .rsa_pss_pss_sha256, + .rsa_pss_pss_sha384, + .rsa_pss_pss_sha512, + => Certificate.rsa.PSSSignature, else => @compileError("bad scheme"), }; } @@ -1350,6 +1641,142 @@ fn SchemeEddsa(comptime scheme: tls.SignatureScheme) type { }; } +fn SchemeHash(comptime scheme: tls.SignatureScheme) type { + return switch (scheme) { + .rsa_pkcs1_sha256, + .ecdsa_secp256r1_sha256, + .rsa_pss_rsae_sha256, + .rsa_pss_pss_sha256, + => crypto.hash.sha2.Sha256, + .rsa_pkcs1_sha384, + .ecdsa_secp384r1_sha384, + .rsa_pss_rsae_sha384, + .rsa_pss_pss_sha384, + => crypto.hash.sha2.Sha384, + .rsa_pkcs1_sha512, + .ecdsa_secp521r1_sha512, + .rsa_pss_rsae_sha512, + .rsa_pss_pss_sha512, + => crypto.hash.sha2.Sha512, + .rsa_pkcs1_sha1, + .ecdsa_sha1, + => crypto.hash.Sha1, + else => @compileError("bad scheme"), + }; +} + +const CertificatePublicKey = struct { + algo: Certificate.AlgorithmCategory, + buf: [600]u8, + len: u16, + + fn init( + cert_pub_key: *CertificatePublicKey, + algo: Certificate.AlgorithmCategory, + pub_key: []const u8, + ) error{CertificatePublicKeyInvalid}!void { + if (pub_key.len > cert_pub_key.buf.len) return error.CertificatePublicKeyInvalid; + cert_pub_key.algo = algo; + @memcpy(cert_pub_key.buf[0..pub_key.len], pub_key); + cert_pub_key.len = @intCast(pub_key.len); + } + + const VerifyError = error{ TlsDecodeError, TlsBadSignatureScheme, InvalidEncoding } || + // ecdsa + crypto.errors.EncodingError || + crypto.errors.NotSquareError || + crypto.errors.NonCanonicalError || + SchemeEcdsa(.ecdsa_secp256r1_sha256).Signature.VerifyError || + SchemeEcdsa(.ecdsa_secp384r1_sha384).Signature.VerifyError || + // rsa + error{TlsBadRsaSignatureBitCount} || + Certificate.rsa.PublicKey.ParseDerError || + Certificate.rsa.PublicKey.FromBytesError || + Certificate.rsa.PSSSignature.VerifyError || + Certificate.rsa.PKCS1v1_5Signature.VerifyError || + // eddsa + SchemeEddsa(.ed25519).Signature.VerifyError; + + fn verifySignature( + cert_pub_key: *const CertificatePublicKey, + sigd: *tls.Decoder, + msg: []const []const u8, + ) VerifyError!void { + const pub_key = cert_pub_key.buf[0..cert_pub_key.len]; + + try sigd.ensure(2 + 2); + const scheme = sigd.decode(tls.SignatureScheme); + const sig_len = sigd.decode(u16); + try sigd.ensure(sig_len); + const encoded_sig = sigd.slice(sig_len); + + if (cert_pub_key.algo != @as(Certificate.AlgorithmCategory, switch (scheme) { + .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => .X9_62_id_ecPublicKey, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .rsa_pkcs1_sha1, + => .rsaEncryption, + .rsa_pss_pss_sha256, + .rsa_pss_pss_sha384, + .rsa_pss_pss_sha512, + => .rsassa_pss, + else => return error.TlsBadSignatureScheme, + })) return error.TlsBadSignatureScheme; + + switch (scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + const Ecdsa = SchemeEcdsa(comptime_scheme); + const sig = try Ecdsa.Signature.fromDer(encoded_sig); + const key = try Ecdsa.PublicKey.fromSec1(pub_key); + try sig.concatVerify(msg, key); + }, + inline .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .rsa_pss_pss_sha256, + .rsa_pss_pss_sha384, + .rsa_pss_pss_sha512, + .rsa_pkcs1_sha1, + => |comptime_scheme| { + const RsaSignature = SchemeRsa(comptime_scheme); + const Hash = SchemeHash(comptime_scheme); + const PublicKey = Certificate.rsa.PublicKey; + const components = try PublicKey.parseDer(pub_key); + const exponent = components.exponent; + const modulus = components.modulus; + switch (modulus.len) { + inline 128, 256, 512 => |modulus_len| { + const key: PublicKey = try .fromBytes(exponent, modulus); + const sig = RsaSignature.fromBytes(modulus_len, encoded_sig); + try RsaSignature.concatVerify(modulus_len, sig, msg, key, Hash); + }, + else => return error.TlsBadRsaSignatureBitCount, + } + }, + inline .ed25519 => |comptime_scheme| { + const Eddsa = SchemeEddsa(comptime_scheme); + if (encoded_sig.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; + const sig = Eddsa.Signature.fromBytes(encoded_sig[0..Eddsa.Signature.encoded_length].*); + if (pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; + const key = try Eddsa.PublicKey.fromBytes(pub_key[0..Eddsa.PublicKey.encoded_length].*); + try sig.concatVerify(msg, key); + }, + else => unreachable, + } + } +}; + /// Abstraction for sending multiple byte buffers to a slice of iovecs. const VecPut = struct { iovecs: []const std.posix.iovec, @@ -1451,16 +1878,22 @@ const cipher_suites = if (crypto.core.aes.has_hardware_support) .AEGIS_128L_SHA256, .AEGIS_256_SHA512, .AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, .AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, .CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, }) else enum_array(tls.CipherSuite, &.{ .CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, .AEGIS_128L_SHA256, .AEGIS_256_SHA512, .AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, .AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, }); test { diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index 78511f435d..c56d3a24a1 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -172,7 +172,13 @@ pub const HeadersParser = struct { const data_avail = r.next_chunk_length; if (skip) { - try conn.fill(); + conn.fill() catch |err| switch (err) { + error.EndOfStream => { + r.done = true; + return 0; + }, + else => |e| return e, + }; const nread = @min(conn.peek().len, data_avail); conn.drop(@intCast(nread)); @@ -196,7 +202,13 @@ pub const HeadersParser = struct { } }, .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { - try conn.fill(); + conn.fill() catch |err| switch (err) { + error.EndOfStream => { + r.done = true; + return 0; + }, + else => |e| return e, + }; const i = r.findChunkedLen(conn.peek()); conn.drop(@intCast(i)); @@ -226,7 +238,13 @@ pub const HeadersParser = struct { const out_avail = buffer.len - out_index; if (skip) { - try conn.fill(); + conn.fill() catch |err| switch (err) { + error.EndOfStream => { + r.done = true; + return 0; + }, + else => |e| return e, + }; const nread = @min(conn.peek().len, data_avail); conn.drop(@intCast(nread)); From e184b15a6639f28e47d4a43d59136726d686b3b1 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Fri, 1 Nov 2024 00:11:44 -0400 Subject: [PATCH 02/14] std.crypto.tls: fix fetching https://nginx.org Note that the removed `error.TlsIllegalParameter` case is still caught below when it is compared to a fixed-length string, but after checking the proper protocol version requirement first. --- lib/std/crypto/tls/Client.zig | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index c69c6ee936..c220c890f8 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -257,7 +257,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In if (handshake_type != .server_hello) return error.TlsUnexpectedMessage; const length = ptd.decode(u24); var hsd = try ptd.sub(length); - try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1); + try hsd.ensure(2 + 32 + 1); const legacy_version = hsd.decode(u16); @memcpy(&server_hello_rand, hsd.array(32)); if (mem.eql(u8, &server_hello_rand, &tls.hello_retry_request_sequence)) { @@ -266,8 +266,8 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In return error.TlsUnexpectedMessage; } const legacy_session_id_echo_len = hsd.decode(u8); - if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter; - const legacy_session_id_echo = hsd.array(32); + try hsd.ensure(legacy_session_id_echo_len + 2 + 1); + const legacy_session_id_echo = hsd.slice(legacy_session_id_echo_len); cipher_suite_tag = hsd.decode(tls.CipherSuite); hsd.skip(1); // legacy_compression_method var supported_version: ?u16 = null; From 4466f145d67b7e343309347b4d1f59f10a3590af Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Fri, 1 Nov 2024 01:48:25 -0400 Subject: [PATCH 03/14] std.crypto.tls: support more key share params This condition is already checked less restrictively in `KeyShare.exchange`. --- lib/std/crypto/tls/Client.zig | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index c220c890f8..dc15c7d813 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -636,7 +636,6 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const curve_type = hsd.decode(u8); if (curve_type != 0x03) return error.TlsIllegalParameter; // named_curve const named_group = hsd.decode(tls.NamedGroup); - if (named_group != .secp256r1) return error.TlsIllegalParameter; const key_size = hsd.decode(u8); try hsd.ensure(key_size); const server_pub_key = hsd.slice(key_size); From 7f20c78c95d54e7fb0958693a9df2dee40dd99d6 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Fri, 1 Nov 2024 16:29:35 -0400 Subject: [PATCH 04/14] std.crypto: delete new functions that are only used once --- lib/std/crypto/25519/ed25519.zig | 11 ++--------- lib/std/crypto/ecdsa.zig | 11 ++--------- lib/std/crypto/tls/Client.zig | 8 ++++++-- 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/lib/std/crypto/25519/ed25519.zig b/lib/std/crypto/25519/ed25519.zig index 3620cfc4ba..3e5fb8bb3c 100644 --- a/lib/std/crypto/25519/ed25519.zig +++ b/lib/std/crypto/25519/ed25519.zig @@ -229,15 +229,8 @@ pub const Ed25519 = struct { /// Return IdentityElement or NonCanonical if the public key or signature are not in the expected range, /// or SignatureVerificationError if the signature is invalid for the given message and key. pub fn verify(sig: Signature, msg: []const u8, public_key: PublicKey) VerifyError!void { - try sig.concatVerify(&.{msg}, public_key); - } - - /// Verify the signature against a concatenated message and public key. - /// Return IdentityElement or NonCanonical if the public key or signature are not in the expected range, - /// or SignatureVerificationError if the signature is invalid for the given message and key. - pub fn concatVerify(sig: Signature, msg: []const []const u8, public_key: PublicKey) VerifyError!void { - var st = try Verifier.init(sig, public_key); - for (msg) |part| st.update(part); + var st = try sig.verifier(public_key); + st.update(msg); try st.verify(); } }; diff --git a/lib/std/crypto/ecdsa.zig b/lib/std/crypto/ecdsa.zig index a015178f3d..649c967218 100644 --- a/lib/std/crypto/ecdsa.zig +++ b/lib/std/crypto/ecdsa.zig @@ -101,15 +101,8 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { /// Return IdentityElement or NonCanonical if the public key or signature are not in the expected range, /// or SignatureVerificationError if the signature is invalid for the given message and key. pub fn verify(sig: Signature, msg: []const u8, public_key: PublicKey) VerifyError!void { - try sig.concatVerify(&.{msg}, public_key); - } - - /// Verify the signature against a concatenated message and public key. - /// Return IdentityElement or NonCanonical if the public key or signature are not in the expected range, - /// or SignatureVerificationError if the signature is invalid for the given message and key. - pub fn concatVerify(sig: Signature, msg: []const []const u8, public_key: PublicKey) VerifyError!void { - var st = try Verifier.init(sig, public_key); - for (msg) |part| st.update(part); + var st = try sig.verifier(public_key); + st.update(msg); try st.verify(); } diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index dc15c7d813..5b1d0a9bf6 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1735,7 +1735,9 @@ const CertificatePublicKey = struct { const Ecdsa = SchemeEcdsa(comptime_scheme); const sig = try Ecdsa.Signature.fromDer(encoded_sig); const key = try Ecdsa.PublicKey.fromSec1(pub_key); - try sig.concatVerify(msg, key); + var ver = try sig.verifier(key); + for (msg) |part| ver.update(part); + try ver.verify(); }, inline .rsa_pkcs1_sha256, .rsa_pkcs1_sha384, @@ -1769,7 +1771,9 @@ const CertificatePublicKey = struct { const sig = Eddsa.Signature.fromBytes(encoded_sig[0..Eddsa.Signature.encoded_length].*); if (pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; const key = try Eddsa.PublicKey.fromBytes(pub_key[0..Eddsa.PublicKey.encoded_length].*); - try sig.concatVerify(msg, key); + var ver = try sig.verifier(key); + for (msg) |part| ver.update(part); + try ver.verify(); }, else => unreachable, } From 7afb2777250a251a065b8a970eb7f14e2d5b5ce2 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Sat, 2 Nov 2024 02:45:12 -0400 Subject: [PATCH 05/14] std.crypto.tls: fix x25519_ml_kem768 key share This is mostly nfc cleanup as I was bisecting the client hello to find the problematic part, and the only bug fix ended up being key_share.x25519_kp.public_key ++ key_share.ml_kem768_kp.public_key.toBytes() to key_share.ml_kem768_kp.public_key.toBytes() ++ key_share.x25519_kp.public_key) and the same swap in `KeyShare.exchange` as per some random blog that says "a hybrid keyshare, constructed by concatenating the public KEM key with the public X25519 key". I also note that based on the same blog post, there was a draft version of this method that indeed had these values swapped, and that used to be supported by this code, but it was not properly fixed up when this code was updated from the draft spec. Closes #21747 --- lib/std/crypto/tls.zig | 67 +++++++---- lib/std/crypto/tls/Client.zig | 206 +++++++++++++++------------------- 2 files changed, 135 insertions(+), 138 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 7732f3b74e..6479c77d75 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -291,6 +291,12 @@ pub const NamedGroup = enum(u16) { _, }; +pub const PskKeyExchangeMode = enum(u8) { + psk_ke = 0, + psk_dhe_ke = 1, + _, +}; + pub const CipherSuite = enum(u16) { RSA_WITH_AES_128_CBC_SHA = 0x002F, DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033, @@ -407,6 +413,11 @@ pub const CipherSuite = enum(u16) { } }; +pub const CompressionMethod = enum(u8) { + null = 0, + _, +}; + pub const CertificateType = enum(u8) { X509 = 0, RawPublicKey = 2, @@ -419,6 +430,11 @@ pub const KeyUpdateRequest = enum(u8) { _, }; +pub const ChangeCipherSpecType = enum(u8) { + change_cipher_spec = 1, + _, +}; + pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type, comptime explicit_iv_length: comptime_int) type { return struct { pub const A = ApplicationCipherT(AeadType, HashType, explicit_iv_length); @@ -560,34 +576,38 @@ pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) return result; } -pub inline fn extension(comptime et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 { - return int2(@intFromEnum(et)) ++ array(1, bytes); +pub inline fn extension(et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 { + return int(u16, @intFromEnum(et)) ++ array(u16, u8, bytes); } -pub inline fn array(comptime elem_size: comptime_int, bytes: anytype) [2 + bytes.len]u8 { - comptime assert(bytes.len % elem_size == 0); - return int2(bytes.len) ++ bytes; -} - -pub inline fn enum_array(comptime E: type, comptime tags: []const E) [2 + @sizeOf(E) * tags.len]u8 { - assert(@sizeOf(E) == 2); - var result: [tags.len * 2]u8 = undefined; - for (tags, 0..) |elem, i| { - result[i * 2] = @as(u8, @truncate(@intFromEnum(elem) >> 8)); - result[i * 2 + 1] = @as(u8, @truncate(@intFromEnum(elem))); +pub inline fn array( + comptime Len: type, + comptime Elem: type, + elems: anytype, +) [@divExact(@bitSizeOf(Len), 8) + @divExact(@bitSizeOf(Elem), 8) * elems.len]u8 { + const len_size = @divExact(@bitSizeOf(Len), 8); + const elem_size = @divExact(@bitSizeOf(Elem), 8); + var arr: [len_size + elem_size * elems.len]u8 = undefined; + std.mem.writeInt(Len, arr[0..len_size], @intCast(elem_size * elems.len), .big); + const ElemInt = @Type(.{ .int = .{ .signedness = .unsigned, .bits = @bitSizeOf(Elem) } }); + for (0.., @as([elems.len]Elem, elems)) |index, elem| { + std.mem.writeInt( + ElemInt, + arr[len_size + elem_size * index ..][0..elem_size], + switch (@typeInfo(Elem)) { + .int => @as(Elem, elem), + .@"enum" => @intFromEnum(@as(Elem, elem)), + else => @bitCast(@as(Elem, elem)), + }, + .big, + ); } - return array(2, result); -} - -pub inline fn int2(int: u16) [2]u8 { - var arr: [2]u8 = undefined; - std.mem.writeInt(u16, &arr, int, .big); return arr; } -pub inline fn int3(int: u24) [3]u8 { - var arr: [3]u8 = undefined; - std.mem.writeInt(u24, &arr, int, .big); +pub inline fn int(comptime Int: type, val: Int) [@divExact(@bitSizeOf(Int), 8)]u8 { + var arr: [@divExact(@bitSizeOf(Int), 8)]u8 = undefined; + std.mem.writeInt(Int, &arr, val, .big); return arr; } @@ -670,9 +690,8 @@ pub const Decoder = struct { else => @compileError("unsupported int type: " ++ @typeName(T)), }, .@"enum" => |info| { - const int = d.decode(info.tag_type); if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); - return @as(T, @enumFromInt(int)); + return @enumFromInt(d.decode(info.tag_type)); }, else => @compileError("unsupported type: " ++ @typeName(T)), } diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 5b1d0a9bf6..a8624fd03f 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -10,10 +10,8 @@ const Certificate = std.crypto.Certificate; const max_ciphertext_len = tls.max_ciphertext_len; const hmacExpandLabel = tls.hmacExpandLabel; const hkdfExpandLabel = tls.hkdfExpandLabel; -const int2 = tls.int2; -const int3 = tls.int3; +const int = tls.int; const array = tls.array; -const enum_array = tls.enum_array; tls_version: tls.ProtocolVersion, read_seq: u64, @@ -156,70 +154,62 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In error.IdentityElement => return error.InsufficientEntropy, }; - const extensions_payload = - tls.extension(.supported_versions, [_]u8{2 + 2} ++ // byte length of supported versions - int2(@intFromEnum(tls.ProtocolVersion.tls_1_3)) ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2))) ++ - tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ + const extensions_payload = tls.extension(.supported_versions, array(u8, tls.ProtocolVersion, .{ + .tls_1_3, + .tls_1_2, + })) ++ tls.extension(.signature_algorithms, array(u16, tls.SignatureScheme, .{ .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384, .rsa_pss_rsae_sha256, .rsa_pss_rsae_sha384, .rsa_pss_rsae_sha512, .ed25519, - })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ + })) ++ tls.extension(.supported_groups, array(u16, tls.NamedGroup, .{ .x25519_ml_kem768, .secp256r1, .x25519, - })) ++ tls.extension( - .key_share, - array(1, int2(@intFromEnum(tls.NamedGroup.x25519)) ++ - array(1, key_share.x25519_kp.public_key) ++ - int2(@intFromEnum(tls.NamedGroup.secp256r1)) ++ - array(1, key_share.secp256r1_kp.public_key.toUncompressedSec1()) ++ - int2(@intFromEnum(tls.NamedGroup.x25519_ml_kem768)) ++ - array(1, key_share.x25519_kp.public_key ++ key_share.ml_kem768_kp.public_key.toBytes())), - ) ++ - int2(@intFromEnum(tls.ExtensionType.server_name)) ++ - int2(host_len + 5) ++ // byte length of this extension payload - int2(host_len + 3) ++ // server_name_list byte count - [1]u8{0x00} ++ // name_type - int2(host_len); + })) ++ tls.extension(.psk_key_exchange_modes, array(u8, tls.PskKeyExchangeMode, .{ + .psk_dhe_ke, + })) ++ tls.extension(.key_share, array( + u16, + u8, + int(u16, @intFromEnum(tls.NamedGroup.x25519_ml_kem768)) ++ + array(u16, u8, key_share.ml_kem768_kp.public_key.toBytes() ++ key_share.x25519_kp.public_key) ++ + int(u16, @intFromEnum(tls.NamedGroup.secp256r1)) ++ + array(u16, u8, key_share.secp256r1_kp.public_key.toUncompressedSec1()) ++ + int(u16, @intFromEnum(tls.NamedGroup.x25519)) ++ + array(u16, u8, key_share.x25519_kp.public_key), + )) ++ int(u16, @intFromEnum(tls.ExtensionType.server_name)) ++ + int(u16, 2 + 1 + 2 + host_len) ++ // byte length of this extension payload + int(u16, 1 + 2 + host_len) ++ // server_name_list byte count + .{0x00} ++ // name_type + int(u16, host_len); const extensions_header = - int2(@intCast(extensions_payload.len + host_len)) ++ + int(u16, @intCast(extensions_payload.len + host_len)) ++ extensions_payload; - const legacy_compression_methods = 0x0100; - const client_hello = - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ client_hello_rand ++ [1]u8{32} ++ legacy_session_id ++ cipher_suites ++ - int2(legacy_compression_methods) ++ + array(u8, tls.CompressionMethod, .{.null}) ++ extensions_header; - const out_handshake = - [_]u8{@intFromEnum(tls.HandshakeType.client_hello)} ++ - int3(@intCast(client_hello.len + host_len)) ++ + const out_handshake = .{@intFromEnum(tls.HandshakeType.client_hello)} ++ + int(u24, @intCast(client_hello.len + host_len)) ++ client_hello; - const cleartext_header = [_]u8{@intFromEnum(tls.ContentType.handshake)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_0)) ++ // legacy_record_version - int2(@intCast(out_handshake.len + host_len)) ++ + const cleartext_header = .{@intFromEnum(tls.ContentType.handshake)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_0)) ++ + int(u16, @intCast(out_handshake.len + host_len)) ++ out_handshake; { var iovecs = [_]std.posix.iovec_const{ - .{ - .base = &cleartext_header, - .len = cleartext_header.len, - }, - .{ - .base = host.ptr, - .len = host.len, - }, + .{ .base = &cleartext_header, .len = cleartext_header.len }, + .{ .base = host.ptr, .len = host.len }, }; try stream.writevAll(&iovecs); } @@ -526,7 +516,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In }, .change_cipher_spec => { try ctd.ensure(1); - if (ctd.decode(u8) != 0x01) return error.TlsIllegalParameter; + if (ctd.decode(tls.ChangeCipherSpecType) != .change_cipher_spec) return error.TlsIllegalParameter; cipher_state = pending_cipher_state; }, .handshake => while (true) { @@ -648,20 +638,13 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In if (handshake_state != .server_hello_done) return error.TlsUnexpectedMessage; handshake_state = .finished; - const client_key_exchange_msg = - [_]u8{@intFromEnum(tls.ContentType.handshake)} ++ // record content type - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version - int2(0x46) ++ // record length - .{@intFromEnum(tls.HandshakeType.client_key_exchange)} ++ // handshake type - int3(0x42) ++ // params length - .{0x41} ++ // pubkey length - key_share.secp256r1_kp.public_key.toUncompressedSec1(); - // This message is to trick buggy proxies into behaving correctly. - const client_change_cipher_spec_msg = - [_]u8{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ // record content type - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version - int2(1) ++ // record length - .{0x01}; + const client_key_exchange_msg = .{@intFromEnum(tls.ContentType.handshake)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, u8, .{@intFromEnum(tls.HandshakeType.client_key_exchange)} ++ + array(u24, u8, array(u8, u8, key_share.secp256r1_kp.public_key.toUncompressedSec1()))); + const client_change_cipher_spec_msg = .{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, tls.ChangeCipherSpecType, .{.change_cipher_spec}); const pre_master_secret = key_share.getSharedSecret().?; switch (handshake_cipher) { inline else => |*p| { @@ -680,10 +663,13 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In @sizeOf(P.Tls_1_2), ); const verify_data_len = 12; - const client_verify_cleartext = - [_]u8{@intFromEnum(tls.HandshakeType.finished)} ++ // handshake type - int3(verify_data_len) ++ // verify data length - hmacExpandLabel(P.Hmac, &master_secret, &.{ "client finished", &p.transcript_hash.peek() }, verify_data_len); + const client_verify_cleartext = .{@intFromEnum(tls.HandshakeType.finished)} ++ + array(u24, u8, hmacExpandLabel( + P.Hmac, + &master_secret, + &.{ "client finished", &p.transcript_hash.peek() }, + verify_data_len, + )); p.transcript_hash.update(&client_verify_cleartext); p.version = .{ .tls_1_2 = .{ .server_verify_data = hmacExpandLabel( @@ -709,25 +695,23 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const operand: V = pad ++ @as([8]u8, @bitCast(big(write_seq))); break :nonce @as(V, pv.app_cipher.client_write_IV ++ pv.app_cipher.client_salt) ^ operand; }; - var client_verify_msg = [_]u8{@intFromEnum(tls.ContentType.handshake)} ++ // record content type - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version - int2(P.record_iv_length + client_verify_cleartext.len + P.mac_length) ++ // record length - nonce[P.fixed_iv_length..].* ++ - @as([client_verify_cleartext.len + P.mac_length]u8, undefined); + var client_verify_msg = .{@intFromEnum(tls.ContentType.handshake)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, u8, nonce[P.fixed_iv_length..].* ++ + @as([client_verify_cleartext.len + P.mac_length]u8, undefined)); P.AEAD.encrypt( client_verify_msg[client_verify_msg.len - P.mac_length - client_verify_cleartext.len ..][0..client_verify_cleartext.len], client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length], &client_verify_cleartext, - std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int2(client_verify_cleartext.len), + std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len), nonce, pv.app_cipher.client_write_key, ); const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg; - var all_msgs_vec = [_]std.posix.iovec_const{.{ - .base = &all_msgs, - .len = all_msgs.len, - }}; + var all_msgs_vec = [_]std.posix.iovec_const{ + .{ .base = &all_msgs, .len = all_msgs.len }, + }; try stream.writevAll(&all_msgs_vec); }, } @@ -755,11 +739,9 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In if (cipher_state == .cleartext) return error.TlsUnexpectedMessage; if (handshake_state != .finished) return error.TlsUnexpectedMessage; // This message is to trick buggy proxies into behaving correctly. - const client_change_cipher_spec_msg = - [_]u8{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version - int2(1) ++ // length - .{0x01}; + const client_change_cipher_spec_msg = .{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, tls.ChangeCipherSpecType, .{.change_cipher_spec}); const app_cipher = app_cipher: switch (handshake_cipher) { inline else => |*p, tag| switch (tls_version) { .tls_1_3 => { @@ -771,17 +753,15 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In if (!mem.eql(u8, &expected_server_verify_data, hsd.buf)) return error.TlsDecryptError; const handshake_hash = p.transcript_hash.finalResult(); const verify_data = tls.hmac(P.Hmac, &handshake_hash, pv.client_finished_key); - const out_cleartext = [_]u8{ - @intFromEnum(tls.HandshakeType.finished), - 0, 0, verify_data.len, // length - } ++ verify_data ++ [1]u8{@intFromEnum(tls.ContentType.handshake)}; + const out_cleartext = .{@intFromEnum(tls.HandshakeType.finished)} ++ + array(u24, u8, verify_data) ++ + .{@intFromEnum(tls.ContentType.handshake)}; const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - var finished_msg = [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ // legacy protocol version - int2(wrapped_len) ++ // byte length of encrypted record - @as([wrapped_len]u8, undefined); + var finished_msg = .{@intFromEnum(tls.ContentType.application_data)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, u8, @as([wrapped_len]u8, undefined)); const ad = finished_msg[0..tls.record_header_len]; const ciphertext = finished_msg[tls.record_header_len..][0..out_cleartext.len]; @@ -790,10 +770,9 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key); const all_msgs = client_change_cipher_spec_msg ++ finished_msg; - var all_msgs_vec = [_]std.posix.iovec_const{.{ - .base = &all_msgs, - .len = all_msgs.len, - }}; + var all_msgs_vec = [_]std.posix.iovec_const{ + .{ .base = &all_msgs, .len = all_msgs.len }, + }; try stream.writevAll(&all_msgs_vec); const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); @@ -965,10 +944,9 @@ fn prepareCiphertextRecord( const record_start = ciphertext_end; const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; - ad.* = - [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - int2(ciphertext_len + P.AEAD.tag_length); + ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + int(u16, ciphertext_len + P.AEAD.tag_length); ciphertext_end += ad.len; const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; ciphertext_end += ciphertext_len; @@ -1023,10 +1001,10 @@ fn prepareCiphertextRecord( const record_start = ciphertext_end; const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; ciphertext_end += tls.record_header_len; - record_header.* = [_]u8{@intFromEnum(inner_content_type)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - int2(P.record_iv_length + message_len + P.mac_length); - const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int2(message_len); + record_header.* = .{@intFromEnum(inner_content_type)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + int(u16, P.record_iv_length + message_len + P.mac_length); + const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len); const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length]; ciphertext_end += P.record_iv_length; const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and @@ -1569,25 +1547,17 @@ const KeyShare = struct { ) error{ TlsIllegalParameter, TlsDecryptFailure }!void { switch (named_group) { .x25519_ml_kem768 => { - const xksl = crypto.dh.X25519.public_length; - const hksl = xksl + crypto.kem.ml_kem.MLKem768.ciphertext_length; - if (server_pub_key.len != hksl) return error.TlsIllegalParameter; + const hksl = crypto.kem.ml_kem.MLKem768.ciphertext_length; + const xksl = hksl + crypto.dh.X25519.public_length; + if (server_pub_key.len != xksl) return error.TlsIllegalParameter; - const xsk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..xksl].*) catch + const hsk = ks.ml_kem768_kp.secret_key.decaps(server_pub_key[0..hksl]) catch return error.TlsDecryptFailure; - const hsk = ks.ml_kem768_kp.secret_key.decaps(server_pub_key[xksl..hksl]) catch + const xsk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[hksl..xksl].*) catch return error.TlsDecryptFailure; - @memcpy(ks.sk_buf[0..xsk.len], &xsk); - @memcpy(ks.sk_buf[xsk.len..][0..hsk.len], &hsk); - ks.sk_len = xsk.len + hsk.len; - }, - .x25519 => { - const ksl = crypto.dh.X25519.public_length; - if (server_pub_key.len != ksl) return error.TlsIllegalParameter; - const sk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..ksl].*) catch - return error.TlsDecryptFailure; - @memcpy(ks.sk_buf[0..sk.len], &sk); - ks.sk_len = sk.len; + @memcpy(ks.sk_buf[0..hsk.len], &hsk); + @memcpy(ks.sk_buf[hsk.len..][0..xsk.len], &xsk); + ks.sk_len = hsk.len + xsk.len; }, .secp256r1 => { const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; @@ -1598,6 +1568,14 @@ const KeyShare = struct { @memcpy(ks.sk_buf[0..sk.len], &sk); ks.sk_len = sk.len; }, + .x25519 => { + const ksl = crypto.dh.X25519.public_length; + if (server_pub_key.len != ksl) return error.TlsIllegalParameter; + const sk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..ksl].*) catch + return error.TlsDecryptFailure; + @memcpy(ks.sk_buf[0..sk.len], &sk); + ks.sk_len = sk.len; + }, else => return error.TlsIllegalParameter, } } @@ -1877,7 +1855,7 @@ fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec { /// aes128-gcm: 138 MiB/s /// aes256-gcm: 120 MiB/s const cipher_suites = if (crypto.core.aes.has_hardware_support) - enum_array(tls.CipherSuite, &.{ + array(u16, tls.CipherSuite, .{ .AEGIS_128L_SHA256, .AEGIS_256_SHA512, .AES_128_GCM_SHA256, @@ -1888,7 +1866,7 @@ const cipher_suites = if (crypto.core.aes.has_hardware_support) .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, }) else - enum_array(tls.CipherSuite, &.{ + array(u16, tls.CipherSuite, .{ .CHACHA20_POLY1305_SHA256, .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, .AEGIS_128L_SHA256, From 90a761c18659919f466715714395ce27425c45f7 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Mon, 4 Nov 2024 19:57:53 -0500 Subject: [PATCH 06/14] std.crypto.tls: make verify data checks timing safe --- lib/std/crypto/tls.zig | 3 ++- lib/std/crypto/tls/Client.zig | 15 ++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 6479c77d75..8c7d3fcdb6 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -442,7 +442,7 @@ pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type, compti transcript_hash: A.Hash, version: union { tls_1_2: struct { - server_verify_data: [12]u8, + expected_server_verify_data: [A.verify_data_length]u8, app_cipher: A.Tls_1_2, }, tls_1_3: struct { @@ -479,6 +479,7 @@ pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type, comp pub const record_iv_length = explicit_iv_length; pub const mac_length = AEAD.tag_length; pub const mac_key_length = Hmac.key_length_min; + pub const verify_data_length = 12; tls_1_2: Tls_1_2, tls_1_3: Tls_1_3, diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index a8624fd03f..6d9a75dc22 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -662,21 +662,20 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In &.{ "key expansion", &server_hello_rand, &client_hello_rand }, @sizeOf(P.Tls_1_2), ); - const verify_data_len = 12; const client_verify_cleartext = .{@intFromEnum(tls.HandshakeType.finished)} ++ array(u24, u8, hmacExpandLabel( P.Hmac, &master_secret, &.{ "client finished", &p.transcript_hash.peek() }, - verify_data_len, + P.verify_data_length, )); p.transcript_hash.update(&client_verify_cleartext); p.version = .{ .tls_1_2 = .{ - .server_verify_data = hmacExpandLabel( + .expected_server_verify_data = hmacExpandLabel( P.Hmac, &master_secret, &.{ "server finished", &p.transcript_hash.finalResult() }, - verify_data_len, + P.verify_data_length, ), .app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block), } }; @@ -747,10 +746,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In .tls_1_3 => { const pv = &p.version.tls_1_3; const P = @TypeOf(p.*).A; + try hsd.ensure(P.Hmac.mac_length); const finished_digest = p.transcript_hash.peek(); p.transcript_hash.update(wrapped_handshake); const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, pv.server_finished_key); - if (!mem.eql(u8, &expected_server_verify_data, hsd.buf)) return error.TlsDecryptError; + if (!std.crypto.timing_safe.eql([P.Hmac.mac_length]u8, expected_server_verify_data, hsd.array(P.Hmac.mac_length).*)) return error.TlsDecryptError; const handshake_hash = p.transcript_hash.finalResult(); const verify_data = tls.hmac(P.Hmac, &handshake_hash, pv.client_finished_key); const out_cleartext = .{@intFromEnum(tls.HandshakeType.finished)} ++ @@ -788,8 +788,9 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In }, .tls_1_2 => { const pv = &p.version.tls_1_2; - try hsd.ensure(12); - if (!std.mem.eql(u8, hsd.array(12), &pv.server_verify_data)) return error.TlsDecryptError; + const P = @TypeOf(p.*).A; + try hsd.ensure(P.verify_data_length); + if (!std.crypto.timing_safe.eql([P.verify_data_length]u8, pv.expected_server_verify_data, hsd.array(P.verify_data_length).*)) return error.TlsDecryptError; break :app_cipher @unionInit(tls.ApplicationCipher, @tagName(tag), .{ .tls_1_2 = pv.app_cipher }); }, else => unreachable, From 485f20a10ab489274b71b0d3106dcf81dbe3ac15 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Mon, 4 Nov 2024 20:45:18 -0500 Subject: [PATCH 07/14] std.crypto.tls: remove hardcoded initial loop This was preventing TLSv1.2 from working in some cases, because servers are allowed to send multiple handshake messages in the first handshake record, whereas this inital loop was assuming that it only contained a server hello. --- lib/std/crypto/tls/Client.zig | 313 +++++++++++++++------------------- 1 file changed, 142 insertions(+), 171 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 6d9a75dc22..2a0d49ca69 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -214,158 +214,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In try stream.writevAll(&iovecs); } - const client_hello_bytes1 = cleartext_header[tls.record_header_len..]; - var tls_version: tls.ProtocolVersion = undefined; - var cipher_suite_tag: tls.CipherSuite = undefined; - var handshake_cipher: tls.HandshakeCipher = undefined; - var handshake_buffer: [8000]u8 = undefined; - var d: tls.Decoder = .{ .buf = &handshake_buffer }; - { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const ct = d.decode(tls.ContentType); - d.skip(2); // legacy_record_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - const server_hello_fragment = d.buf[d.idx..][0..record_len]; - var ptd = try d.sub(record_len); - switch (ct) { - .alert => { - try ptd.ensure(2); - const level = ptd.decode(tls.AlertLevel); - const desc = ptd.decode(tls.AlertDescription); - _ = level; - - // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - try ptd.ensure(4); - const handshake_type = ptd.decode(tls.HandshakeType); - if (handshake_type != .server_hello) return error.TlsUnexpectedMessage; - const length = ptd.decode(u24); - var hsd = try ptd.sub(length); - try hsd.ensure(2 + 32 + 1); - const legacy_version = hsd.decode(u16); - @memcpy(&server_hello_rand, hsd.array(32)); - if (mem.eql(u8, &server_hello_rand, &tls.hello_retry_request_sequence)) { - // This is a HelloRetryRequest message. This client implementation - // does not expect to get one. - return error.TlsUnexpectedMessage; - } - const legacy_session_id_echo_len = hsd.decode(u8); - try hsd.ensure(legacy_session_id_echo_len + 2 + 1); - const legacy_session_id_echo = hsd.slice(legacy_session_id_echo_len); - cipher_suite_tag = hsd.decode(tls.CipherSuite); - hsd.skip(1); // legacy_compression_method - var supported_version: ?u16 = null; - if (!hsd.eof()) { - try hsd.ensure(2); - const extensions_size = hsd.decode(u16); - var all_extd = try hsd.sub(extensions_size); - while (!all_extd.eof()) { - try all_extd.ensure(2 + 2); - const et = all_extd.decode(tls.ExtensionType); - const ext_size = all_extd.decode(u16); - var extd = try all_extd.sub(ext_size); - switch (et) { - .supported_versions => { - if (supported_version) |_| return error.TlsIllegalParameter; - try extd.ensure(2); - supported_version = extd.decode(u16); - }, - .key_share => { - if (key_share.getSharedSecret()) |_| return error.TlsIllegalParameter; - try extd.ensure(4); - const named_group = extd.decode(tls.NamedGroup); - const key_size = extd.decode(u16); - try extd.ensure(key_size); - try key_share.exchange(named_group, extd.slice(key_size)); - }, - else => {}, - } - } - } - - tls_version = @enumFromInt(supported_version orelse legacy_version); - switch (tls_version) { - .tls_1_3 => if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) return error.TlsIllegalParameter, - .tls_1_2 => if (mem.eql(u8, server_hello_rand[24..31], "DOWNGRD") and - server_hello_rand[31] >> 1 == 0x00) return error.TlsIllegalParameter, - else => return error.TlsIllegalParameter, - } - - switch (cipher_suite_tag) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_256_SHA512, - .AEGIS_128L_SHA256, - - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - => |tag| { - handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag.with()), .{ - .transcript_hash = .init(.{}), - .version = undefined, - }); - const p = &@field(handshake_cipher, @tagName(tag.with())); - p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 - p.transcript_hash.update(host); // Client Hello part 2 - p.transcript_hash.update(server_hello_fragment); - }, - - else => return error.TlsIllegalParameter, - } - switch (tls_version) { - .tls_1_3 => switch (cipher_suite_tag) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_256_SHA512, - .AEGIS_128L_SHA256, - => |tag| { - const sk = key_share.getSharedSecret() orelse return error.TlsIllegalParameter; - const p = &@field(handshake_cipher, @tagName(tag.with())); - const P = @TypeOf(p.*).A; - const hello_hash = p.transcript_hash.peek(); - const zeroes = [1]u8{0} ** P.Hash.digest_length; - const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); - const empty_hash = tls.emptyHash(P.Hash); - p.version = .{ .tls_1_3 = undefined }; - const pv = &p.version.tls_1_3; - const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); - pv.handshake_secret = P.Hkdf.extract(&hs_derived_secret, sk); - const ap_derived_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); - pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); - const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); - pv.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); - pv.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); - pv.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - pv.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - pv.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - pv.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - else => return error.TlsIllegalParameter, - }, - .tls_1_2 => switch (cipher_suite_tag) { - .ECDHE_RSA_WITH_AES_128_GCM_SHA256, - .ECDHE_RSA_WITH_AES_256_GCM_SHA384, - .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - => {}, - else => return error.TlsIllegalParameter, - }, - else => return error.TlsIllegalParameter, - } - }, - else => return error.TlsUnexpectedMessage, - } - } - // This is used for two purposes: // * Detect whether a certificate is the first one presented, in which case // we need to verify the host name. @@ -384,13 +233,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In /// Application cipher is in use application, }; - var pending_cipher_state: CipherState = switch (tls_version) { - .tls_1_3 => .handshake, - .tls_1_2 => .cleartext, - else => unreachable, - }; - var cipher_state: CipherState = .cleartext; + var pending_cipher_state: CipherState = .cleartext; + var cipher_state = pending_cipher_state; const HandshakeState = enum { + /// In this state we expect only a server hello message. + hello, /// In this state we expect only an encrypted_extensions message. encrypted_extensions, /// In this state we expect certificate handshake messages. @@ -404,15 +251,14 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In /// In this state, we expect only the finished handshake message. finished, }; - var handshake_state: HandshakeState = switch (tls_version) { - .tls_1_3 => .encrypted_extensions, - .tls_1_2 => .certificate, - else => unreachable, - }; - var cleartext_bufs: [2][8000]u8 = undefined; + var handshake_state: HandshakeState = .hello; + var handshake_cipher: tls.HandshakeCipher = undefined; var main_cert_pub_key: CertificatePublicKey = undefined; const now_sec = std.time.timestamp(); + var cleartext_bufs: [2][8000]u8 = undefined; + var handshake_buffer: [8000]u8 = undefined; + var d: tls.Decoder = .{ .buf = &handshake_buffer }; while (true) { try d.readAtLeastOurAmt(stream, tls.record_header_len); const record_header = d.buf[d.idx..][0..tls.record_header_len]; @@ -526,11 +372,132 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In var hsd = try ctd.sub(handshake_len); const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; switch (handshake_type) { + .server_hello => { + if (cipher_state != .cleartext) return error.TlsUnexpectedMessage; + if (handshake_state != .hello) return error.TlsUnexpectedMessage; + try hsd.ensure(2 + 32 + 1); + const legacy_version = hsd.decode(u16); + @memcpy(&server_hello_rand, hsd.array(32)); + if (mem.eql(u8, &server_hello_rand, &tls.hello_retry_request_sequence)) { + // This is a HelloRetryRequest message. This client implementation + // does not expect to get one. + return error.TlsUnexpectedMessage; + } + const legacy_session_id_echo_len = hsd.decode(u8); + try hsd.ensure(legacy_session_id_echo_len + 2 + 1); + const legacy_session_id_echo = hsd.slice(legacy_session_id_echo_len); + const cipher_suite_tag = hsd.decode(tls.CipherSuite); + hsd.skip(1); // legacy_compression_method + var supported_version: ?u16 = null; + if (!hsd.eof()) { + try hsd.ensure(2); + const extensions_size = hsd.decode(u16); + var all_extd = try hsd.sub(extensions_size); + while (!all_extd.eof()) { + try all_extd.ensure(2 + 2); + const et = all_extd.decode(tls.ExtensionType); + const ext_size = all_extd.decode(u16); + var extd = try all_extd.sub(ext_size); + switch (et) { + .supported_versions => { + if (supported_version) |_| return error.TlsIllegalParameter; + try extd.ensure(2); + supported_version = extd.decode(u16); + }, + .key_share => { + if (key_share.getSharedSecret()) |_| return error.TlsIllegalParameter; + try extd.ensure(4); + const named_group = extd.decode(tls.NamedGroup); + const key_size = extd.decode(u16); + try extd.ensure(key_size); + try key_share.exchange(named_group, extd.slice(key_size)); + }, + else => {}, + } + } + } + + tls_version = @enumFromInt(supported_version orelse legacy_version); + switch (tls_version) { + .tls_1_3 => if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) return error.TlsIllegalParameter, + .tls_1_2 => if (mem.eql(u8, server_hello_rand[24..31], "DOWNGRD") and + server_hello_rand[31] >> 1 == 0x00) return error.TlsIllegalParameter, + else => return error.TlsIllegalParameter, + } + + switch (cipher_suite_tag) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_256_SHA512, + .AEGIS_128L_SHA256, + + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => |tag| { + handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag.with()), .{ + .transcript_hash = .init(.{}), + .version = undefined, + }); + const p = &@field(handshake_cipher, @tagName(tag.with())); + p.transcript_hash.update(cleartext_header[tls.record_header_len..]); // Client Hello part 1 + p.transcript_hash.update(host); // Client Hello part 2 + p.transcript_hash.update(wrapped_handshake); + }, + + else => return error.TlsIllegalParameter, + } + switch (tls_version) { + .tls_1_3 => { + switch (cipher_suite_tag) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_256_SHA512, + .AEGIS_128L_SHA256, + => |tag| { + const sk = key_share.getSharedSecret() orelse return error.TlsIllegalParameter; + const p = &@field(handshake_cipher, @tagName(tag.with())); + const P = @TypeOf(p.*).A; + const hello_hash = p.transcript_hash.peek(); + const zeroes = [1]u8{0} ** P.Hash.digest_length; + const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); + const empty_hash = tls.emptyHash(P.Hash); + p.version = .{ .tls_1_3 = undefined }; + const pv = &p.version.tls_1_3; + const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); + pv.handshake_secret = P.Hkdf.extract(&hs_derived_secret, sk); + const ap_derived_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); + pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); + const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); + pv.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); + pv.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); + pv.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + pv.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + pv.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + pv.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + }, + else => return error.TlsIllegalParameter, + } + pending_cipher_state = .handshake; + handshake_state = .encrypted_extensions; + }, + .tls_1_2 => switch (cipher_suite_tag) { + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => handshake_state = .certificate, + else => return error.TlsIllegalParameter, + }, + else => return error.TlsIllegalParameter, + } + }, .encrypted_extensions => { if (tls_version != .tls_1_3) return error.TlsUnexpectedMessage; if (cipher_state != .handshake) return error.TlsUnexpectedMessage; if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; - handshake_state = .certificate; switch (handshake_cipher) { inline else => |*p| p.transcript_hash.update(wrapped_handshake), } @@ -548,16 +515,18 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In else => {}, } } + handshake_state = .certificate; }, .certificate => cert: { - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } + if (cipher_state == .application) return error.TlsUnexpectedMessage; switch (handshake_state) { .certificate => {}, .trust_chain_established => break :cert, else => return error.TlsUnexpectedMessage, } + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } switch (tls_version) { .tls_1_3 => { @@ -614,7 +583,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage; if (cipher_state != .cleartext) return error.TlsUnexpectedMessage; switch (handshake_state) { - .trust_chain_established => handshake_state = .server_hello_done, + .trust_chain_established => {}, .certificate => return error.TlsCertificateNotVerified, else => return error.TlsUnexpectedMessage, } @@ -631,12 +600,12 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const server_pub_key = hsd.slice(key_size); try main_cert_pub_key.verifySignature(&hsd, &.{ &client_hello_rand, &server_hello_rand, hsd.buf[0..hsd.idx] }); try key_share.exchange(named_group, server_pub_key); + handshake_state = .server_hello_done; }, .server_hello_done => { if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage; if (cipher_state != .cleartext) return error.TlsUnexpectedMessage; if (handshake_state != .server_hello_done) return error.TlsUnexpectedMessage; - handshake_state = .finished; const client_key_exchange_msg = .{@intFromEnum(tls.ContentType.handshake)} ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ @@ -680,7 +649,6 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In .app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block), } }; const pv = &p.version.tls_1_2; - pending_cipher_state = .application; const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) nonce: { @@ -715,12 +683,14 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In }, } write_seq += 1; + pending_cipher_state = .application; + handshake_state = .finished; }, .certificate_verify => { if (tls_version != .tls_1_3) return error.TlsUnexpectedMessage; if (cipher_state != .handshake) return error.TlsUnexpectedMessage; switch (handshake_state) { - .trust_chain_established => handshake_state = .finished, + .trust_chain_established => {}, .certificate => return error.TlsCertificateNotVerified, else => return error.TlsUnexpectedMessage, } @@ -733,6 +703,7 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In p.transcript_hash.update(wrapped_handshake); }, } + handshake_state = .finished; }, .finished => { if (cipher_state == .cleartext) return error.TlsUnexpectedMessage; From d86a8aedd5674819ec4af1bfc8a81b3fef91fd85 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Mon, 4 Nov 2024 22:43:31 -0500 Subject: [PATCH 08/14] std.crypto.tls: increase handshake buffer sizes --- lib/std/crypto/tls/Client.zig | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 2a0d49ca69..4665a0ba38 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -256,8 +256,8 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In var main_cert_pub_key: CertificatePublicKey = undefined; const now_sec = std.time.timestamp(); - var cleartext_bufs: [2][8000]u8 = undefined; - var handshake_buffer: [8000]u8 = undefined; + var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined; + var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined; var d: tls.Decoder = .{ .buf = &handshake_buffer }; while (true) { try d.readAtLeastOurAmt(stream, tls.record_header_len); From de53e6e4f2dc7a41dc50b309fee87e06475e4838 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Tue, 5 Nov 2024 01:37:12 -0500 Subject: [PATCH 09/14] std.crypto.tls: improve debuggability of encrypted connections By default, programs built in debug mode that open a https connection will append secrets to the file specified in the SSLKEYLOGFILE environment variable to allow protocol debugging by external programs. --- lib/std/crypto/tls/Client.zig | 171 +++++++++++++++++++++++++++++----- lib/std/http/Client.zig | 29 ++++-- lib/std/std.zig | 5 + 3 files changed, 174 insertions(+), 31 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 4665a0ba38..e10a7273c9 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -33,7 +33,7 @@ received_close_notify: bool, /// This makes the application vulnerable to truncation attacks unless the /// application layer itself verifies that the amount of data received equals /// the amount of data expected, such as HTTP with the Content-Length header. -allow_truncation_attacks: bool = false, +allow_truncation_attacks: bool, application_cipher: tls.ApplicationCipher, /// The size is enough to contain exactly one TLSCiphertext record. /// This buffer is segmented into four parts: @@ -44,6 +44,24 @@ application_cipher: tls.ApplicationCipher, /// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and /// `partial_ciphertext_end` describe the span of the segments. partially_read_buffer: [tls.max_ciphertext_record_len]u8, +/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other +/// programs with access to that file to decrypt all traffic over this connection. +ssl_key_log: ?struct { + client_key_seq: u64, + server_key_seq: u64, + client_random: [32]u8, + file: std.fs.File, + + fn clientCounter(key_log: *@This()) u64 { + defer key_log.client_key_seq += 1; + return key_log.client_key_seq; + } + + fn serverCounter(key_log: *@This()) u64 { + defer key_log.server_key_seq += 1; + return key_log.server_key_seq; + } +}, /// This is an example of the type that is needed by the read and write /// functions. It can have any fields but it must at least have these @@ -88,6 +106,32 @@ pub const StreamInterface = struct { } }; +pub const Options = struct { + /// How to perform host verification of server certificates. + host: union(enum) { + /// No host verification is performed, which prevents a trusted connection from + /// being established. + no_verification, + /// Verify that the server certificate was issues for a given host. + explicit: []const u8, + }, + /// How to verify the authenticity of server certificates. + ca: union(enum) { + /// No ca verification is performed, which prevents a trusted connection from + /// being established. + no_verification, + /// Verify that the server certificate is a valid self-signed certificate. + /// This provides no authorization guarantees, as anyone can create a + /// self-signed certificate. + self_signed, + /// Verify that the server certificate is authorized by a given ca bundle. + bundle: Certificate.Bundle, + }, + /// If non-null, ssl secrets are logged to this file. Creating such a log file allows + /// other programs with access to that file to decrypt all traffic over this connection. + ssl_key_log_file: ?std.fs.File = null, +}; + pub fn InitError(comptime Stream: type) type { return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{ InsufficientEntropy, @@ -140,12 +184,17 @@ pub fn InitError(comptime Stream: type) type { /// must conform to `StreamInterface`. /// /// `host` is only borrowed during this function call. -pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) InitError(@TypeOf(stream))!Client { +pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client { + const host = switch (options.host) { + .no_verification => "", + .explicit => |host| host, + }; const host_len: u16 = @intCast(host.len); var random_buffer: [128]u8 = undefined; crypto.random.bytes(&random_buffer); const client_hello_rand = random_buffer[0..32].*; + var key_seq: u64 = 0; var server_hello_rand: [32]u8 = undefined; const legacy_session_id = random_buffer[32..64].*; @@ -179,15 +228,21 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In array(u16, u8, key_share.secp256r1_kp.public_key.toUncompressedSec1()) ++ int(u16, @intFromEnum(tls.NamedGroup.x25519)) ++ array(u16, u8, key_share.x25519_kp.public_key), - )) ++ int(u16, @intFromEnum(tls.ExtensionType.server_name)) ++ + )); + const server_name_extension = int(u16, @intFromEnum(tls.ExtensionType.server_name)) ++ int(u16, 2 + 1 + 2 + host_len) ++ // byte length of this extension payload int(u16, 1 + 2 + host_len) ++ // server_name_list byte count .{0x00} ++ // name_type int(u16, host_len); + const server_name_extension_len = switch (options.host) { + .no_verification => 0, + .explicit => server_name_extension.len + host_len, + }; const extensions_header = - int(u16, @intCast(extensions_payload.len + host_len)) ++ - extensions_payload; + int(u16, @intCast(extensions_payload.len + server_name_extension_len)) ++ + extensions_payload ++ + server_name_extension; const client_hello = int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ @@ -198,20 +253,24 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In extensions_header; const out_handshake = .{@intFromEnum(tls.HandshakeType.client_hello)} ++ - int(u24, @intCast(client_hello.len + host_len)) ++ + int(u24, @intCast(client_hello.len - server_name_extension.len + server_name_extension_len)) ++ client_hello; - const cleartext_header = .{@intFromEnum(tls.ContentType.handshake)} ++ + const cleartext_header_buf = .{@intFromEnum(tls.ContentType.handshake)} ++ int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_0)) ++ - int(u16, @intCast(out_handshake.len + host_len)) ++ + int(u16, @intCast(out_handshake.len - server_name_extension.len + server_name_extension_len)) ++ out_handshake; + const cleartext_header = switch (options.host) { + .no_verification => cleartext_header_buf[0 .. cleartext_header_buf.len - server_name_extension.len], + .explicit => &cleartext_header_buf, + }; { var iovecs = [_]std.posix.iovec_const{ - .{ .base = &cleartext_header, .len = cleartext_header.len }, + .{ .base = cleartext_header.ptr, .len = cleartext_header.len }, .{ .base = host.ptr, .len = host.len }, }; - try stream.writevAll(&iovecs); + try stream.writevAll(iovecs[0..if (host.len == 0) 1 else 2]); } var tls_version: tls.ProtocolVersion = undefined; @@ -472,6 +531,12 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); + if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + .client_random = &client_hello_rand, + }, .{ + .SERVER_HANDSHAKE_TRAFFIC_SECRET = &server_secret, + .CLIENT_HANDSHAKE_TRAFFIC_SECRET = &client_secret, + }); pv.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); pv.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); pv.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); @@ -544,6 +609,13 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const cert_size = certs_decoder.decode(u24); const certd = try certs_decoder.sub(cert_size); + if (tls_version == .tls_1_3) { + try certs_decoder.ensure(2); + const total_ext_size = certs_decoder.decode(u16); + const all_extd = try certs_decoder.sub(total_ext_size); + _ = all_extd; + } + const subject_cert: Certificate = .{ .buffer = certd.buf, .index = @intCast(certd.idx), @@ -551,7 +623,10 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const subject = try subject_cert.parse(); if (cert_index == 0) { // Verify the host on the first certificate. - try subject.verifyHostName(host); + switch (options.host) { + .no_verification => {}, + .explicit => try subject.verifyHostName(host), + } // Keep track of the public key for the // certificate_verify message later. @@ -560,23 +635,27 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In try prev_cert.verify(subject, now_sec); } - if (ca_bundle.verify(subject, now_sec)) |_| { - handshake_state = .trust_chain_established; - break :cert; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - else => |e| return e, + switch (options.ca) { + .no_verification => { + handshake_state = .trust_chain_established; + break :cert; + }, + .self_signed => { + try subject.verify(subject, now_sec); + handshake_state = .trust_chain_established; + break :cert; + }, + .bundle => |ca_bundle| if (ca_bundle.verify(subject, now_sec)) |_| { + handshake_state = .trust_chain_established; + break :cert; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + else => |e| return e, + }, } prev_cert = subject; cert_index += 1; - - if (tls_version == .tls_1_3) { - try certs_decoder.ensure(2); - const total_ext_size = certs_decoder.decode(u16); - const all_extd = try certs_decoder.sub(total_ext_size); - _ = all_extd; - } } }, .server_key_exchange => { @@ -625,6 +704,11 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In &client_hello_rand, &server_hello_rand, }, 48); + if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + .client_random = &client_hello_rand, + }, .{ + .CLIENT_RANDOM = &master_secret, + }); const key_block = hmacExpandLabel( P.Hmac, &master_secret, @@ -748,6 +832,14 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + .counter = key_seq, + .client_random = &client_hello_rand, + }, .{ + .SERVER_TRAFFIC_SECRET = &server_secret, + .CLIENT_TRAFFIC_SECRET = &client_secret, + }); + key_seq += 1; break :app_cipher @unionInit(tls.ApplicationCipher, @tagName(tag), .{ .tls_1_3 = .{ .client_secret = client_secret, .server_secret = server_secret, @@ -784,8 +876,15 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In .partial_ciphertext_idx = 0, .partial_ciphertext_end = @intCast(leftover.len), .received_close_notify = false, + .allow_truncation_attacks = false, .application_cipher = app_cipher, .partially_read_buffer = undefined, + .ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{ + .client_key_seq = key_seq, + .server_key_seq = key_seq, + .client_random = client_hello_rand, + .file = key_log_file, + } else null, }; @memcpy(client.partially_read_buffer[0..leftover.len], leftover); return client; @@ -1358,6 +1457,12 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove const pv = &p.tls_1_3; const P = @TypeOf(p.*); const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length); + if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{ + .counter = key_log.serverCounter(), + .client_random = &key_log.client_random, + }, .{ + .SERVER_TRAFFIC_SECRET = &server_secret, + }); pv.server_secret = server_secret; pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); @@ -1372,6 +1477,12 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove const pv = &p.tls_1_3; const P = @TypeOf(p.*); const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length); + if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{ + .counter = key_log.clientCounter(), + .client_random = &key_log.client_random, + }, .{ + .CLIENT_TRAFFIC_SECRET = &client_secret, + }); pv.client_secret = client_secret; pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); @@ -1426,6 +1537,18 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove } } +fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void { + const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false; + defer if (locked) key_log_file.unlock(); + key_log_file.seekFromEnd(0) catch {}; + inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| key_log_file.writer().print("{s}" ++ + (if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {} {}\n", .{field.name} ++ + (if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{ + std.fmt.fmtSliceHexLower(context.client_random), + std.fmt.fmtSliceHexLower(@field(secrets, field.name)), + }) catch {}; +} + fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { const saved_buf = frag[in..]; if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 6e95995ee0..cddc6297c9 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -388,6 +388,7 @@ pub const Connection = struct { // try to cleanly close the TLS connection, for any server that cares. _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; + if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close(); allocator.destroy(conn.tls_client); } @@ -566,7 +567,7 @@ pub const Response = struct { .reason = undefined, .version = undefined, .keep_alive = false, - .parser = proto.HeadersParser.init(&header_buffer), + .parser = .init(&header_buffer), }; @memcpy(header_buffer[0..response_bytes.len], response_bytes); @@ -610,7 +611,7 @@ pub const Response = struct { } pub fn iterateHeaders(r: Response) http.HeaderIterator { - return http.HeaderIterator.init(r.parser.get()); + return .init(r.parser.get()); } test iterateHeaders { @@ -628,7 +629,7 @@ pub const Response = struct { .reason = undefined, .version = undefined, .keep_alive = false, - .parser = proto.HeadersParser.init(&header_buffer), + .parser = .init(&header_buffer), }; @memcpy(header_buffer[0..response_bytes.len], response_bytes); @@ -771,7 +772,7 @@ pub const Request = struct { req.client.connection_pool.release(req.client.allocator, req.connection.?); req.connection = null; - var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer); + var server_header: std.heap.FixedBufferAllocator = .init(req.response.parser.header_bytes_buffer); defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); @@ -1354,7 +1355,21 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); errdefer client.allocator.destroy(conn.data.tls_client); - conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; + const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: { + const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) { + error.EnvironmentVariableNotFound, error.InvalidWtf8 => break :ssl_key_log_file null, + error.OutOfMemory => return error.OutOfMemory, + }; + defer client.allocator.free(ssl_key_log_path); + break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{ .truncate = false }) catch null; + } else null; + errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close(); + + conn.data.tls_client.* = std.crypto.tls.Client.init(stream, .{ + .host = .{ .explicit = host }, + .ca = .{ .bundle = client.ca_bundle }, + .ssl_key_log_file = ssl_key_log_file, + }) catch return error.TlsInitializationFailed; // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. conn.data.tls_client.allow_truncation_attacks = true; @@ -1620,7 +1635,7 @@ pub fn open( } } - var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); + var server_header: std.heap.FixedBufferAllocator = .init(options.server_header_buffer); const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { @@ -1654,7 +1669,7 @@ pub fn open( .status = undefined, .reason = undefined, .keep_alive = undefined, - .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), + .parser = .init(server_header.buffer[server_header.end_index..]), }, .headers = options.headers, .extra_headers = options.extra_headers, diff --git a/lib/std/std.zig b/lib/std/std.zig index 6dbb4c0843..cc61111746 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -146,6 +146,11 @@ pub const Options = struct { /// make a HTTPS connection. http_disable_tls: bool = false, + /// This enables `std.http.Client` to log ssl secrets to the file specified by the SSLKEYLOGFILE + /// env var. Creating such a log file allows other programs with access to that file to decrypt + /// all `std.http.Client` traffic made by this program. + http_enable_ssl_key_log_file: bool = @import("builtin").mode == .Debug, + side_channels_mitigations: crypto.SideChannelsMitigations = crypto.default_side_channels_mitigations, }; From a6ede7ba86987b9ae2bb6b8aac60f66af56e7b08 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Tue, 5 Nov 2024 02:24:14 -0500 Subject: [PATCH 10/14] std.crypto.tls: support handshake fragments --- lib/std/crypto/tls/Client.zig | 54 +++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index e10a7273c9..922f7b66cc 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -274,13 +274,14 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client } var tls_version: tls.ProtocolVersion = undefined; - // This is used for two purposes: + // These are used for two purposes: // * Detect whether a certificate is the first one presented, in which case // we need to verify the host name. + var cert_index: usize = 0; // * Flip back and forth between the two cleartext buffers in order to keep // the previous certificate in memory so that it can be verified by the // next one. - var cert_index: usize = 0; + var cert_buf_index: usize = 0; var write_seq: u64 = 0; var read_seq: u64 = 0; var prev_cert: Certificate.Parsed = undefined; @@ -315,10 +316,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client var main_cert_pub_key: CertificatePublicKey = undefined; const now_sec = std.time.timestamp(); + var cleartext_fragment_start: usize = 0; + var cleartext_fragment_end: usize = 0; var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined; var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined; var d: tls.Decoder = .{ .buf = &handshake_buffer }; - while (true) { + fragment: while (true) { try d.readAtLeastOurAmt(stream, tls.record_header_len); const record_header = d.buf[d.idx..][0..tls.record_header_len]; const record_ct = d.decode(tls.ContentType); @@ -332,15 +335,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client std.debug.assert(tls_version == .tls_1_3); if (record_ct != .application_data) return error.TlsUnexpectedMessage; try record_decoder.ensure(record_len); - const cleartext_buf = &cleartext_bufs[cert_index % 2]; - const cleartext = cleartext: switch (handshake_cipher) { + const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; + switch (handshake_cipher) { inline else => |*p| { const pv = &p.version.tls_1_3; const P = @TypeOf(p.*).A; if (record_len < P.AEAD.tag_length) return error.TlsRecordOverflow; const ciphertext = record_decoder.slice(record_len - P.AEAD.tag_length); - if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; - const cleartext = cleartext_buf[0..ciphertext.len]; + const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..]; + if (ciphertext.len > cleartext_fragment_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_fragment_buf[0..ciphertext.len]; const auth_tag = record_decoder.array(P.AEAD.tag_length).*; const nonce = if (builtin.zig_backend == .stage2_x86_64 and P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) @@ -357,27 +361,29 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client }; P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch return error.TlsBadRecordMac; - break :cleartext mem.trimRight(u8, cleartext, "\x00"); + cleartext_fragment_end += std.mem.trimRight(u8, cleartext, "\x00").len; }, - }; + } read_seq += 1; - const ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); + cleartext_fragment_end -= 1; + const ct: tls.ContentType = @enumFromInt(cleartext_buf[cleartext_fragment_end]); if (ct != .handshake) return error.TlsUnexpectedMessage; - break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext[0 .. cleartext.len - 1])), ct }; + break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct }; }, .application => { std.debug.assert(tls_version == .tls_1_2); if (record_ct != .handshake) return error.TlsUnexpectedMessage; try record_decoder.ensure(record_len); - const cleartext_buf = &cleartext_bufs[cert_index % 2]; - const cleartext = cleartext: switch (handshake_cipher) { + const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; + switch (handshake_cipher) { inline else => |*p| { const pv = &p.version.tls_1_2; const P = @TypeOf(p.*).A; if (record_len < P.record_iv_length + P.mac_length) return error.TlsRecordOverflow; const message_len: u16 = record_len - P.record_iv_length - P.mac_length; - if (message_len > cleartext_buf.len) return error.TlsRecordOverflow; - const cleartext = cleartext_buf[0..message_len]; + const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..]; + if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_fragment_buf[0..message_len]; const ad = std.mem.toBytes(big(read_seq)) ++ record_header[0 .. 1 + 2] ++ std.mem.toBytes(big(message_len)); @@ -400,16 +406,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client const ciphertext = record_decoder.slice(message_len); const auth_tag = record_decoder.array(P.mac_length); P.AEAD.decrypt(cleartext, ciphertext, auth_tag.*, ad, nonce, pv.app_cipher.server_write_key) catch return error.TlsBadRecordMac; - break :cleartext cleartext; + cleartext_fragment_end += message_len; }, - }; + } read_seq += 1; - break :content .{ tls.Decoder.fromTheirSlice(cleartext), record_ct }; + break :content .{ tls.Decoder.fromTheirSlice(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end]), record_ct }; }, }; switch (ct) { .alert => { - try ctd.ensure(2); + ctd.ensure(2) catch continue :fragment; const level = ctd.decode(tls.AlertLevel); const desc = ctd.decode(tls.AlertDescription); _ = level; @@ -420,15 +426,15 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client return error.TlsUnexpectedMessage; }, .change_cipher_spec => { - try ctd.ensure(1); + ctd.ensure(1) catch continue :fragment; if (ctd.decode(tls.ChangeCipherSpecType) != .change_cipher_spec) return error.TlsIllegalParameter; cipher_state = pending_cipher_state; }, .handshake => while (true) { - try ctd.ensure(4); + ctd.ensure(4) catch continue :fragment; const handshake_type = ctd.decode(tls.HandshakeType); const handshake_len = ctd.decode(u24); - var hsd = try ctd.sub(handshake_len); + var hsd = ctd.sub(handshake_len) catch continue :fragment; const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; switch (handshake_type) { .server_hello => { @@ -657,6 +663,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client prev_cert = subject; cert_index += 1; } + cert_buf_index += 1; }, .server_key_exchange => { if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage; @@ -892,9 +899,12 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client else => return error.TlsUnexpectedMessage, } if (ctd.eof()) break; + cleartext_fragment_start = ctd.idx; }, else => return error.TlsUnexpectedMessage, } + cleartext_fragment_start = 0; + cleartext_fragment_end = 0; } } From fbaefcaa946b74b0702cf5a60e76b694d870d04e Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Tue, 5 Nov 2024 04:19:35 -0500 Subject: [PATCH 11/14] std.crypto.tls: support the same key sizes as certificate verification --- lib/std/crypto/tls/Client.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 922f7b66cc..bbee90275f 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -1840,7 +1840,7 @@ const CertificatePublicKey = struct { const exponent = components.exponent; const modulus = components.modulus; switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { + inline 128, 256, 384, 512 => |modulus_len| { const key: PublicKey = try .fromBytes(exponent, modulus); const sig = RsaSignature.fromBytes(modulus_len, encoded_sig); try RsaSignature.concatVerify(modulus_len, sig, msg, key, Hash); From a4e88abf042f89cf98ff0b54dfe82eb3ad6eaa0b Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Tue, 5 Nov 2024 22:42:49 -0500 Subject: [PATCH 12/14] std.crypto.tls: advertise all supported signature algorithms --- lib/std/crypto/tls.zig | 11 +++++++++++ lib/std/crypto/tls/Client.zig | 7 +++++++ 2 files changed, 18 insertions(+) diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index 8c7d3fcdb6..74113225cb 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -266,6 +266,17 @@ pub const SignatureScheme = enum(u16) { rsa_pkcs1_sha1 = 0x0201, ecdsa_sha1 = 0x0203, + ecdsa_brainpoolP256r1tls13_sha256 = 0x081a, + ecdsa_brainpoolP384r1tls13_sha384 = 0x081b, + ecdsa_brainpoolP512r1tls13_sha512 = 0x081c, + + rsa_sha224 = 0x0301, + dsa_sha224 = 0x0302, + ecdsa_sha224 = 0x0303, + dsa_sha256 = 0x0402, + dsa_sha384 = 0x0502, + dsa_sha512 = 0x0602, + _, }; diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index bbee90275f..fd52e34137 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -209,9 +209,16 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client })) ++ tls.extension(.signature_algorithms, array(u16, tls.SignatureScheme, .{ .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, .rsa_pss_rsae_sha256, .rsa_pss_rsae_sha384, .rsa_pss_rsae_sha512, + .rsa_pss_pss_sha256, + .rsa_pss_pss_sha384, + .rsa_pss_pss_sha512, + .rsa_pkcs1_sha1, .ed25519, })) ++ tls.extension(.supported_groups, array(u16, tls.NamedGroup, .{ .x25519_ml_kem768, From 75adba7cb9501f33453275c187bcd7f4b11eaa9d Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Wed, 6 Nov 2024 00:14:27 -0500 Subject: [PATCH 13/14] std.crypto.tls: add support for secp384r1 key share --- lib/std/crypto/tls/Client.zig | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index fd52e34137..20ea485049 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -191,14 +191,14 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client }; const host_len: u16 = @intCast(host.len); - var random_buffer: [128]u8 = undefined; + var random_buffer: [176]u8 = undefined; crypto.random.bytes(&random_buffer); const client_hello_rand = random_buffer[0..32].*; var key_seq: u64 = 0; var server_hello_rand: [32]u8 = undefined; const legacy_session_id = random_buffer[32..64].*; - var key_share = KeyShare.init(random_buffer[64..128].*) catch |err| switch (err) { + var key_share = KeyShare.init(random_buffer[64..176].*) catch |err| switch (err) { // Only possible to happen if the seed is all zeroes. error.IdentityElement => return error.InsufficientEntropy, }; @@ -223,6 +223,7 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client })) ++ tls.extension(.supported_groups, array(u16, tls.NamedGroup, .{ .x25519_ml_kem768, .secp256r1, + .secp384r1, .x25519, })) ++ tls.extension(.psk_key_exchange_modes, array(u8, tls.PskKeyExchangeMode, .{ .psk_dhe_ke, @@ -233,6 +234,8 @@ pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client array(u16, u8, key_share.ml_kem768_kp.public_key.toBytes() ++ key_share.x25519_kp.public_key) ++ int(u16, @intFromEnum(tls.NamedGroup.secp256r1)) ++ array(u16, u8, key_share.secp256r1_kp.public_key.toUncompressedSec1()) ++ + int(u16, @intFromEnum(tls.NamedGroup.secp384r1)) ++ + array(u16, u8, key_share.secp384r1_kp.public_key.toUncompressedSec1()) ++ int(u16, @intFromEnum(tls.NamedGroup.x25519)) ++ array(u16, u8, key_share.x25519_kp.public_key), )); @@ -1630,23 +1633,26 @@ inline fn big(x: anytype) @TypeOf(x) { } const KeyShare = struct { - x25519_kp: crypto.dh.X25519.KeyPair, - secp256r1_kp: crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair, ml_kem768_kp: crypto.kem.ml_kem.MLKem768.KeyPair, + secp256r1_kp: crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair, + secp384r1_kp: crypto.sign.ecdsa.EcdsaP384Sha384.KeyPair, + x25519_kp: crypto.dh.X25519.KeyPair, sk_buf: [sk_max_len]u8, sk_len: std.math.IntFittingRange(0, sk_max_len), const sk_max_len = @max( crypto.dh.X25519.shared_length + crypto.kem.ml_kem.MLKem768.shared_length, - crypto.dh.X25519.shared_length, crypto.ecc.P256.scalar.encoded_length, + crypto.ecc.P384.scalar.encoded_length, + crypto.dh.X25519.shared_length, ); - fn init(seed: [64]u8) error{IdentityElement}!KeyShare { + fn init(seed: [112]u8) error{IdentityElement}!KeyShare { return .{ - .x25519_kp = try .create(seed[0..32].*), - .secp256r1_kp = try .create(seed[32..64].*), .ml_kem768_kp = try .create(null), + .secp256r1_kp = try .create(seed[0..32].*), + .secp384r1_kp = try .create(seed[32..80].*), + .x25519_kp = try .create(seed[80..112].*), .sk_buf = undefined, .sk_len = 0, }; @@ -1680,6 +1686,15 @@ const KeyShare = struct { @memcpy(ks.sk_buf[0..sk.len], &sk); ks.sk_len = sk.len; }, + .secp384r1 => { + const PublicKey = crypto.sign.ecdsa.EcdsaP384Sha384.PublicKey; + const pk = PublicKey.fromSec1(server_pub_key) catch return error.TlsDecryptFailure; + const mul = pk.p.mulPublic(ks.secp384r1_kp.secret_key.bytes, .big) catch + return error.TlsDecryptFailure; + const sk = mul.affineCoordinates().x.toBytes(.big); + @memcpy(ks.sk_buf[0..sk.len], &sk); + ks.sk_len = sk.len; + }, .x25519 => { const ksl = crypto.dh.X25519.public_length; if (server_pub_key.len != ksl) return error.TlsIllegalParameter; From 9373abf7f77c37094f9ba6ca68287d8a06ebafa0 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Thu, 7 Nov 2024 20:25:04 -0500 Subject: [PATCH 14/14] std.http.Client: change ssl key log creation permission bits This is the same mode used by openssh for private keys. This does not change the mode of an existing file, so users who need something different can pre-create the file with their designed permissions or change them after the fact, and running another process that writes to the key log will not change it back. --- lib/std/http/Client.zig | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index cddc6297c9..9dcf7b5693 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -1361,7 +1361,13 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec error.OutOfMemory => return error.OutOfMemory, }; defer client.allocator.free(ssl_key_log_path); - break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{ .truncate = false }) catch null; + break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{ + .truncate = false, + .mode = switch (builtin.os.tag) { + .windows, .wasi => 0, + else => 0o600, + }, + }) catch null; } else null; errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close();