diff --git a/lib/std/crypto/25519/ed25519.zig b/lib/std/crypto/25519/ed25519.zig index d7b51271d2..3e5fb8bb3c 100644 --- a/lib/std/crypto/25519/ed25519.zig +++ b/lib/std/crypto/25519/ed25519.zig @@ -151,7 +151,9 @@ pub const Ed25519 = struct { a: Curve, expected_r: Curve, - fn init(sig: Signature, public_key: PublicKey) (NonCanonicalError || EncodingError || IdentityElementError)!Verifier { + pub const InitError = NonCanonicalError || EncodingError || IdentityElementError; + + fn init(sig: Signature, public_key: PublicKey) InitError!Verifier { const r = sig.r; const s = sig.s; try Curve.scalar.rejectNonCanonical(s); @@ -173,8 +175,11 @@ pub const Ed25519 = struct { self.h.update(msg); } + pub const VerifyError = WeakPublicKeyError || IdentityElementError || + SignatureVerificationError; + /// Verify that the signature is valid for the entire message. - pub fn verify(self: *Verifier) (SignatureVerificationError || WeakPublicKeyError || IdentityElementError)!void { + pub fn verify(self: *Verifier) VerifyError!void { var hram64: [Sha512.digest_length]u8 = undefined; self.h.final(&hram64); const hram = Curve.scalar.reduce64(hram64); @@ -197,10 +202,10 @@ pub const Ed25519 = struct { s: CompressedScalar, /// Return the raw signature (r, s) in little-endian format. - pub fn toBytes(self: Signature) [encoded_length]u8 { + pub fn toBytes(sig: Signature) [encoded_length]u8 { var bytes: [encoded_length]u8 = undefined; - bytes[0..Curve.encoded_length].* = self.r; - bytes[Curve.encoded_length..].* = self.s; + bytes[0..Curve.encoded_length].* = sig.r; + bytes[Curve.encoded_length..].* = sig.s; return bytes; } @@ -214,17 +219,19 @@ pub const Ed25519 = struct { } /// Create a Verifier for incremental verification of a signature. - pub fn verifier(self: Signature, public_key: PublicKey) (NonCanonicalError || EncodingError || IdentityElementError)!Verifier { - return Verifier.init(self, public_key); + pub fn verifier(sig: Signature, public_key: PublicKey) Verifier.InitError!Verifier { + return Verifier.init(sig, public_key); } + pub const VerifyError = Verifier.InitError || Verifier.VerifyError; + /// Verify the signature against a message and public key. /// Return IdentityElement or NonCanonical if the public key or signature are not in the expected range, /// or SignatureVerificationError if the signature is invalid for the given message and key. - pub fn verify(self: Signature, msg: []const u8, public_key: PublicKey) (IdentityElementError || NonCanonicalError || SignatureVerificationError || EncodingError || WeakPublicKeyError)!void { - var st = try Verifier.init(self, public_key); + pub fn verify(sig: Signature, msg: []const u8, public_key: PublicKey) VerifyError!void { + var st = try sig.verifier(public_key); st.update(msg); - return st.verify(); + try st.verify(); } }; diff --git a/lib/std/crypto/Certificate.zig b/lib/std/crypto/Certificate.zig index 3580d11fcd..9ca6aa5e48 100644 --- a/lib/std/crypto/Certificate.zig +++ b/lib/std/crypto/Certificate.zig @@ -20,18 +20,18 @@ pub const Algorithm = enum { curveEd25519, pub const map = std.StaticStringMap(Algorithm).initComptime(.{ - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x01 }, .ecdsa_with_SHA224 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02 }, .ecdsa_with_SHA256 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03 }, .ecdsa_with_SHA384 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04 }, .ecdsa_with_SHA512 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x02 }, .md2WithRSAEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x04 }, .md5WithRSAEncryption }, - .{ &[_]u8{ 0x2B, 0x65, 0x70 }, .curveEd25519 }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x05 }, .sha1WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0B }, .sha256WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C }, .sha384WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D }, .sha512WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0E }, .sha224WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x01 }, .ecdsa_with_SHA224 }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02 }, .ecdsa_with_SHA256 }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03 }, .ecdsa_with_SHA384 }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04 }, .ecdsa_with_SHA512 }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x02 }, .md2WithRSAEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x04 }, .md5WithRSAEncryption }, + .{ &.{ 0x2B, 0x65, 0x70 }, .curveEd25519 }, }); pub fn Hash(comptime algorithm: Algorithm) type { @@ -49,13 +49,15 @@ pub const Algorithm = enum { pub const AlgorithmCategory = enum { rsaEncryption, + rsassa_pss, X9_62_id_ecPublicKey, curveEd25519, pub const map = std.StaticStringMap(AlgorithmCategory).initComptime(.{ - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey }, - .{ &[_]u8{ 0x2B, 0x65, 0x70 }, .curveEd25519 }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x01 }, .rsaEncryption }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0A }, .rsassa_pss }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01 }, .X9_62_id_ecPublicKey }, + .{ &.{ 0x2B, 0x65, 0x70 }, .curveEd25519 }, }); }; @@ -74,18 +76,18 @@ pub const Attribute = enum { domainComponent, pub const map = std.StaticStringMap(Attribute).initComptime(.{ - .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, - .{ &[_]u8{ 0x55, 0x04, 0x05 }, .serialNumber }, - .{ &[_]u8{ 0x55, 0x04, 0x06 }, .countryName }, - .{ &[_]u8{ 0x55, 0x04, 0x07 }, .localityName }, - .{ &[_]u8{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, - .{ &[_]u8{ 0x55, 0x04, 0x09 }, .streetAddress }, - .{ &[_]u8{ 0x55, 0x04, 0x0A }, .organizationName }, - .{ &[_]u8{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, - .{ &[_]u8{ 0x55, 0x04, 0x11 }, .postalCode }, - .{ &[_]u8{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x09, 0x01 }, .pkcs9_emailAddress }, - .{ &[_]u8{ 0x09, 0x92, 0x26, 0x89, 0x93, 0xF2, 0x2C, 0x64, 0x01, 0x19 }, .domainComponent }, + .{ &.{ 0x55, 0x04, 0x03 }, .commonName }, + .{ &.{ 0x55, 0x04, 0x05 }, .serialNumber }, + .{ &.{ 0x55, 0x04, 0x06 }, .countryName }, + .{ &.{ 0x55, 0x04, 0x07 }, .localityName }, + .{ &.{ 0x55, 0x04, 0x08 }, .stateOrProvinceName }, + .{ &.{ 0x55, 0x04, 0x09 }, .streetAddress }, + .{ &.{ 0x55, 0x04, 0x0A }, .organizationName }, + .{ &.{ 0x55, 0x04, 0x0B }, .organizationalUnitName }, + .{ &.{ 0x55, 0x04, 0x11 }, .postalCode }, + .{ &.{ 0x55, 0x04, 0x61 }, .organizationIdentifier }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x09, 0x01 }, .pkcs9_emailAddress }, + .{ &.{ 0x09, 0x92, 0x26, 0x89, 0x93, 0xF2, 0x2C, 0x64, 0x01, 0x19 }, .domainComponent }, }); }; @@ -95,9 +97,9 @@ pub const NamedCurve = enum { X9_62_prime256v1, pub const map = std.StaticStringMap(NamedCurve).initComptime(.{ - .{ &[_]u8{ 0x2B, 0x81, 0x04, 0x00, 0x22 }, .secp384r1 }, - .{ &[_]u8{ 0x2B, 0x81, 0x04, 0x00, 0x23 }, .secp521r1 }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07 }, .X9_62_prime256v1 }, + .{ &.{ 0x2B, 0x81, 0x04, 0x00, 0x22 }, .secp384r1 }, + .{ &.{ 0x2B, 0x81, 0x04, 0x00, 0x23 }, .secp521r1 }, + .{ &.{ 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07 }, .X9_62_prime256v1 }, }); pub fn Curve(comptime curve: NamedCurve) type { @@ -131,28 +133,28 @@ pub const ExtensionId = enum { netscape_comment, pub const map = std.StaticStringMap(ExtensionId).initComptime(.{ - .{ &[_]u8{ 0x55, 0x04, 0x03 }, .commonName }, - .{ &[_]u8{ 0x55, 0x1D, 0x01 }, .authority_key_identifier }, - .{ &[_]u8{ 0x55, 0x1D, 0x07 }, .subject_alt_name }, - .{ &[_]u8{ 0x55, 0x1D, 0x0E }, .subject_key_identifier }, - .{ &[_]u8{ 0x55, 0x1D, 0x0F }, .key_usage }, - .{ &[_]u8{ 0x55, 0x1D, 0x0A }, .basic_constraints }, - .{ &[_]u8{ 0x55, 0x1D, 0x10 }, .private_key_usage_period }, - .{ &[_]u8{ 0x55, 0x1D, 0x11 }, .subject_alt_name }, - .{ &[_]u8{ 0x55, 0x1D, 0x12 }, .issuer_alt_name }, - .{ &[_]u8{ 0x55, 0x1D, 0x13 }, .basic_constraints }, - .{ &[_]u8{ 0x55, 0x1D, 0x14 }, .crl_number }, - .{ &[_]u8{ 0x55, 0x1D, 0x1F }, .crl_distribution_points }, - .{ &[_]u8{ 0x55, 0x1D, 0x20 }, .certificate_policies }, - .{ &[_]u8{ 0x55, 0x1D, 0x23 }, .authority_key_identifier }, - .{ &[_]u8{ 0x55, 0x1D, 0x25 }, .ext_key_usage }, - .{ &[_]u8{ 0x2B, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x01 }, .msCertsrvCAVersion }, - .{ &[_]u8{ 0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x01 }, .info_access }, - .{ &[_]u8{ 0x2A, 0x86, 0x48, 0x86, 0xF6, 0x7D, 0x07, 0x41, 0x00 }, .entrustVersInfo }, - .{ &[_]u8{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x14, 0x02 }, .enroll_certtype }, - .{ &[_]u8{ 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x0c }, .pe_logotype }, - .{ &[_]u8{ 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x01 }, .netscape_cert_type }, - .{ &[_]u8{ 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x0d }, .netscape_comment }, + .{ &.{ 0x55, 0x04, 0x03 }, .commonName }, + .{ &.{ 0x55, 0x1D, 0x01 }, .authority_key_identifier }, + .{ &.{ 0x55, 0x1D, 0x07 }, .subject_alt_name }, + .{ &.{ 0x55, 0x1D, 0x0E }, .subject_key_identifier }, + .{ &.{ 0x55, 0x1D, 0x0F }, .key_usage }, + .{ &.{ 0x55, 0x1D, 0x0A }, .basic_constraints }, + .{ &.{ 0x55, 0x1D, 0x10 }, .private_key_usage_period }, + .{ &.{ 0x55, 0x1D, 0x11 }, .subject_alt_name }, + .{ &.{ 0x55, 0x1D, 0x12 }, .issuer_alt_name }, + .{ &.{ 0x55, 0x1D, 0x13 }, .basic_constraints }, + .{ &.{ 0x55, 0x1D, 0x14 }, .crl_number }, + .{ &.{ 0x55, 0x1D, 0x1F }, .crl_distribution_points }, + .{ &.{ 0x55, 0x1D, 0x20 }, .certificate_policies }, + .{ &.{ 0x55, 0x1D, 0x23 }, .authority_key_identifier }, + .{ &.{ 0x55, 0x1D, 0x25 }, .ext_key_usage }, + .{ &.{ 0x2B, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x15, 0x01 }, .msCertsrvCAVersion }, + .{ &.{ 0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x01 }, .info_access }, + .{ &.{ 0x2A, 0x86, 0x48, 0x86, 0xF6, 0x7D, 0x07, 0x41, 0x00 }, .entrustVersInfo }, + .{ &.{ 0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x14, 0x02 }, .enroll_certtype }, + .{ &.{ 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x0c }, .pe_logotype }, + .{ &.{ 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x01 }, .netscape_cert_type }, + .{ &.{ 0x60, 0x86, 0x48, 0x01, 0x86, 0xf8, 0x42, 0x01, 0x0d }, .netscape_comment }, }); }; @@ -185,6 +187,7 @@ pub const Parsed = struct { pub const PubKeyAlgo = union(AlgorithmCategory) { rsaEncryption: void, + rsassa_pss: void, X9_62_id_ecPublicKey: NamedCurve, curveEd25519: void, }; @@ -386,7 +389,7 @@ test "Parsed.checkHostName" { try expectEqual(true, Parsed.checkHostName("bar.ziglang.org", "*.Ziglang.ORG")); } -pub const ParseError = der.Element.ParseElementError || ParseVersionError || ParseTimeError || ParseEnumError || ParseBitStringError; +pub const ParseError = der.Element.ParseError || ParseVersionError || ParseTimeError || ParseEnumError || ParseBitStringError; pub fn parse(cert: Certificate) ParseError!Parsed { const cert_bytes = cert.buffer; @@ -413,13 +416,9 @@ pub fn parse(cert: Certificate) ParseError!Parsed { const pub_key_info = try der.Element.parse(cert_bytes, subject.slice.end); const pub_key_signature_algorithm = try der.Element.parse(cert_bytes, pub_key_info.slice.start); const pub_key_algo_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.start); - const pub_key_algo_tag = try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem); - var pub_key_algo: Parsed.PubKeyAlgo = undefined; - switch (pub_key_algo_tag) { - .rsaEncryption => { - pub_key_algo = .{ .rsaEncryption = {} }; - }, - .X9_62_id_ecPublicKey => { + const pub_key_algo: Parsed.PubKeyAlgo = switch (try parseAlgorithmCategory(cert_bytes, pub_key_algo_elem)) { + inline else => |tag| @unionInit(Parsed.PubKeyAlgo, @tagName(tag), {}), + .X9_62_id_ecPublicKey => pub_key_algo: { // RFC 5480 Section 2.1.1.1 Named Curve // ECParameters ::= CHOICE { // namedCurve OBJECT IDENTIFIER @@ -428,12 +427,9 @@ pub fn parse(cert: Certificate) ParseError!Parsed { // } const params_elem = try der.Element.parse(cert_bytes, pub_key_algo_elem.slice.end); const named_curve = try parseNamedCurve(cert_bytes, params_elem); - pub_key_algo = .{ .X9_62_id_ecPublicKey = named_curve }; + break :pub_key_algo .{ .X9_62_id_ecPublicKey = named_curve }; }, - .curveEd25519 => { - pub_key_algo = .{ .curveEd25519 = {} }; - }, - } + }; const pub_key_elem = try der.Element.parse(cert_bytes, pub_key_signature_algorithm.slice.end); const pub_key = try parseBitString(cert, pub_key_elem); @@ -731,7 +727,7 @@ pub fn parseVersion(bytes: []const u8, version_elem: der.Element) ParseVersionEr fn verifyRsa( comptime Hash: type, - message: []const u8, + msg: []const u8, sig: []const u8, pub_key_algo: Parsed.PubKeyAlgo, pub_key: []const u8, @@ -743,59 +739,14 @@ fn verifyRsa( if (exponent.len > modulus.len) return error.CertificatePublicKeyInvalid; if (sig.len != modulus.len) return error.CertificateSignatureInvalidLength; - const hash_der = switch (Hash) { - crypto.hash.Sha1 => [_]u8{ - 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, - 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, - }, - crypto.hash.sha2.Sha224 => [_]u8{ - 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, - 0x00, 0x04, 0x1c, - }, - crypto.hash.sha2.Sha256 => [_]u8{ - 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, - 0x00, 0x04, 0x20, - }, - crypto.hash.sha2.Sha384 => [_]u8{ - 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, - 0x00, 0x04, 0x30, - }, - crypto.hash.sha2.Sha512 => [_]u8{ - 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, - 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, - 0x00, 0x04, 0x40, - }, - else => @compileError("unreachable"), - }; - - var msg_hashed: [Hash.digest_length]u8 = undefined; - Hash.hash(message, &msg_hashed, .{}); - switch (modulus.len) { inline 128, 256, 384, 512 => |modulus_len| { - const ps_len = modulus_len - (hash_der.len + msg_hashed.len) - 3; - const em: [modulus_len]u8 = - [2]u8{ 0, 1 } ++ - ([1]u8{0xff} ** ps_len) ++ - [1]u8{0} ++ - hash_der ++ - msg_hashed; - - const public_key = rsa.PublicKey.fromBytes(exponent, modulus) catch return error.CertificateSignatureInvalid; - const em_dec = rsa.encrypt(modulus_len, sig[0..modulus_len].*, public_key) catch |err| switch (err) { - error.MessageTooLong => unreachable, - }; - - if (!mem.eql(u8, &em, &em_dec)) { + const public_key = rsa.PublicKey.fromBytes(exponent, modulus) catch + return error.CertificateSignatureInvalid; + rsa.PKCS1v1_5Signature.verify(modulus_len, sig[0..modulus_len].*, msg, public_key, Hash) catch return error.CertificateSignatureInvalid; - } - }, - else => { - return error.CertificateSignatureUnsupportedBitCount; }, + else => return error.CertificateSignatureUnsupportedBitCount, } } @@ -908,9 +859,9 @@ pub const der = struct { pub const empty: Slice = .{ .start = 0, .end = 0 }; }; - pub const ParseElementError = error{CertificateFieldHasInvalidLength}; + pub const ParseError = error{CertificateFieldHasInvalidLength}; - pub fn parse(bytes: []const u8, index: u32) ParseElementError!Element { + pub fn parse(bytes: []const u8, index: u32) Element.ParseError!Element { var i = index; const identifier = @as(Identifier, @bitCast(bytes[i])); i += 1; @@ -958,21 +909,41 @@ pub const rsa = struct { const Modulus = std.crypto.ff.Modulus(max_modulus_bits); const Fe = Modulus.Fe; + /// RFC 3447 8.1 RSASSA-PSS pub const PSSSignature = struct { pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 { - var result = [1]u8{0} ** modulus_len; - std.mem.copyForwards(u8, &result, msg); + var result: [modulus_len]u8 = undefined; + @memcpy(result[0..msg.len], msg); + @memset(result[msg.len..], 0); return result; } - pub fn verify(comptime modulus_len: usize, sig: [modulus_len]u8, msg: []const u8, public_key: PublicKey, comptime Hash: type) !void { + pub const VerifyError = EncryptError || error{InvalidSignature}; + + pub fn verify( + comptime modulus_len: usize, + sig: [modulus_len]u8, + msg: []const u8, + public_key: PublicKey, + comptime Hash: type, + ) VerifyError!void { + try concatVerify(modulus_len, sig, &.{msg}, public_key, Hash); + } + + pub fn concatVerify( + comptime modulus_len: usize, + sig: [modulus_len]u8, + msg: []const []const u8, + public_key: PublicKey, + comptime Hash: type, + ) VerifyError!void { const mod_bits = public_key.n.bits(); const em_dec = try encrypt(modulus_len, sig, public_key); - EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash) catch unreachable; + try EMSA_PSS_VERIFY(msg, &em_dec, mod_bits - 1, Hash.digest_length, Hash); } - fn EMSA_PSS_VERIFY(msg: []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type) !void { + fn EMSA_PSS_VERIFY(msg: []const []const u8, em: []const u8, emBit: usize, sLen: usize, comptime Hash: type) VerifyError!void { // 1. If the length of M is greater than the input limitation for // the hash function (2^61 - 1 octets for SHA-1), output // "inconsistent" and stop. @@ -986,7 +957,11 @@ pub const rsa = struct { // 2. Let mHash = Hash(M), an octet string of length hLen. var mHash: [Hash.digest_length]u8 = undefined; - Hash.hash(msg, &mHash, .{}); + { + var hasher: Hash = .init(.{}); + for (msg) |part| hasher.update(part); + hasher.final(&mHash); + } // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop. if (emLen < Hash.digest_length + sLen + 2) { @@ -1082,25 +1057,14 @@ pub const rsa = struct { } fn MGF1(comptime Hash: type, out: []u8, seed: *const [Hash.digest_length]u8, len: usize) ![]u8 { - var counter: usize = 0; + var counter: u32 = 0; var idx: usize = 0; - var c: [4]u8 = undefined; - var hash: [Hash.digest_length + c.len]u8 = undefined; - @memcpy(hash[0..Hash.digest_length], seed); - var hashed: [Hash.digest_length]u8 = undefined; + var hash = seed.* ++ @as([4]u8, undefined); while (idx < len) { - c[0] = @as(u8, @intCast((counter >> 24) & 0xFF)); - c[1] = @as(u8, @intCast((counter >> 16) & 0xFF)); - c[2] = @as(u8, @intCast((counter >> 8) & 0xFF)); - c[3] = @as(u8, @intCast(counter & 0xFF)); - - std.mem.copyForwards(u8, hash[seed.len..], &c); - Hash.hash(&hash, &hashed, .{}); - - std.mem.copyForwards(u8, out[idx..], &hashed); - idx += hashed.len; - + std.mem.writeInt(u32, hash[seed.len..][0..4], counter, .big); + Hash.hash(&hash, out[idx..][0..Hash.digest_length], .{}); + idx += Hash.digest_length; counter += 1; } @@ -1108,11 +1072,128 @@ pub const rsa = struct { } }; + /// RFC 3447 8.2 RSASSA-PKCS1-v1_5 + pub const PKCS1v1_5Signature = struct { + pub fn fromBytes(comptime modulus_len: usize, msg: []const u8) [modulus_len]u8 { + var result: [modulus_len]u8 = undefined; + @memcpy(result[0..msg.len], msg); + @memset(result[msg.len..], 0); + return result; + } + + pub const VerifyError = EncryptError || error{InvalidSignature}; + + pub fn verify( + comptime modulus_len: usize, + sig: [modulus_len]u8, + msg: []const u8, + public_key: PublicKey, + comptime Hash: type, + ) VerifyError!void { + try concatVerify(modulus_len, sig, &.{msg}, public_key, Hash); + } + + pub fn concatVerify( + comptime modulus_len: usize, + sig: [modulus_len]u8, + msg: []const []const u8, + public_key: PublicKey, + comptime Hash: type, + ) VerifyError!void { + const em_dec = try encrypt(modulus_len, sig, public_key); + const em = try EMSA_PKCS1_V1_5_ENCODE(msg, modulus_len, Hash); + if (!std.mem.eql(u8, &em_dec, &em)) return error.InvalidSignature; + } + + fn EMSA_PKCS1_V1_5_ENCODE(msg: []const []const u8, comptime emLen: usize, comptime Hash: type) VerifyError![emLen]u8 { + comptime var em_index = emLen; + var em: [emLen]u8 = undefined; + + // 1. Apply the hash function to the message M to produce a hash value + // H: + // + // H = Hash(M). + // + // If the hash function outputs "message too long," output "message + // too long" and stop. + var hasher: Hash = .init(.{}); + for (msg) |part| hasher.update(part); + em_index -= Hash.digest_length; + hasher.final(em[em_index..]); + + // 2. Encode the algorithm ID for the hash function and the hash value + // into an ASN.1 value of type DigestInfo (see Appendix A.2.4) with + // the Distinguished Encoding Rules (DER), where the type DigestInfo + // has the syntax + // + // DigestInfo ::= SEQUENCE { + // digestAlgorithm AlgorithmIdentifier, + // digest OCTET STRING + // } + // + // The first field identifies the hash function and the second + // contains the hash value. Let T be the DER encoding of the + // DigestInfo value (see the notes below) and let tLen be the length + // in octets of T. + const hash_der: []const u8 = &switch (Hash) { + crypto.hash.Sha1 => .{ + 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, + 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14, + }, + crypto.hash.sha2.Sha224 => .{ + 0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, + 0x00, 0x04, 0x1c, + }, + crypto.hash.sha2.Sha256 => .{ + 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, + 0x00, 0x04, 0x20, + }, + crypto.hash.sha2.Sha384 => .{ + 0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, + 0x00, 0x04, 0x30, + }, + crypto.hash.sha2.Sha512 => .{ + 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, + 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, + 0x00, 0x04, 0x40, + }, + else => @compileError("unreachable"), + }; + em_index -= hash_der.len; + @memcpy(em[em_index..][0..hash_der.len], hash_der); + + // 3. If emLen < tLen + 11, output "intended encoded message length too + // short" and stop. + + // 4. Generate an octet string PS consisting of emLen - tLen - 3 octets + // with hexadecimal value 0xff. The length of PS will be at least 8 + // octets. + em_index -= 1; + @memset(em[2..em_index], 0xff); + + // 5. Concatenate PS, the DER encoding T, and other padding to form the + // encoded message EM as + // + // EM = 0x00 || 0x01 || PS || 0x00 || T. + em[em_index] = 0x00; + em[1] = 0x01; + em[0] = 0x00; + + // 6. Output EM. + return em; + } + }; + pub const PublicKey = struct { n: Modulus, e: Fe, - pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8) !PublicKey { + pub const FromBytesError = error{CertificatePublicKeyInvalid}; + + pub fn fromBytes(pub_bytes: []const u8, modulus_bytes: []const u8) FromBytesError!PublicKey { // Reject modulus below 512 bits. // 512-bit RSA was factored in 1999, so this limit barely means anything, // but establish some limit now to ratchet in what we can. @@ -1137,7 +1218,9 @@ pub const rsa = struct { }; } - pub fn parseDer(pub_key: []const u8) !struct { modulus: []const u8, exponent: []const u8 } { + pub const ParseDerError = der.Element.ParseError || error{CertificateFieldHasWrongDataType}; + + pub fn parseDer(pub_key: []const u8) ParseDerError!struct { modulus: []const u8, exponent: []const u8 } { const pub_key_seq = try der.Element.parse(pub_key, 0); if (pub_key_seq.identifier.tag != .sequence) return error.CertificateFieldHasWrongDataType; const modulus_elem = try der.Element.parse(pub_key, pub_key_seq.slice.start); @@ -1156,7 +1239,9 @@ pub const rsa = struct { } }; - fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey) ![modulus_len]u8 { + const EncryptError = error{MessageTooLong}; + + fn encrypt(comptime modulus_len: usize, msg: [modulus_len]u8, public_key: PublicKey) EncryptError![modulus_len]u8 { const m = Fe.fromBytes(public_key.n, &msg, .big) catch return error.MessageTooLong; const e = public_key.n.powPublic(m, public_key.e) catch unreachable; var res: [modulus_len]u8 = undefined; diff --git a/lib/std/crypto/ecdsa.zig b/lib/std/crypto/ecdsa.zig index 50f13a010d..649c967218 100644 --- a/lib/std/crypto/ecdsa.zig +++ b/lib/std/crypto/ecdsa.zig @@ -91,24 +91,26 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { s: Curve.scalar.CompressedScalar, /// Create a Verifier for incremental verification of a signature. - pub fn verifier(self: Signature, public_key: PublicKey) (NonCanonicalError || EncodingError || IdentityElementError)!Verifier { - return Verifier.init(self, public_key); + pub fn verifier(sig: Signature, public_key: PublicKey) Verifier.InitError!Verifier { + return Verifier.init(sig, public_key); } + pub const VerifyError = Verifier.InitError || Verifier.VerifyError; + /// Verify the signature against a message and public key. /// Return IdentityElement or NonCanonical if the public key or signature are not in the expected range, /// or SignatureVerificationError if the signature is invalid for the given message and key. - pub fn verify(self: Signature, msg: []const u8, public_key: PublicKey) (IdentityElementError || NonCanonicalError || SignatureVerificationError)!void { - var st = try Verifier.init(self, public_key); + pub fn verify(sig: Signature, msg: []const u8, public_key: PublicKey) VerifyError!void { + var st = try sig.verifier(public_key); st.update(msg); - return st.verify(); + try st.verify(); } /// Return the raw signature (r, s) in big-endian format. - pub fn toBytes(self: Signature) [encoded_length]u8 { + pub fn toBytes(sig: Signature) [encoded_length]u8 { var bytes: [encoded_length]u8 = undefined; - @memcpy(bytes[0 .. encoded_length / 2], &self.r); - @memcpy(bytes[encoded_length / 2 ..], &self.s); + @memcpy(bytes[0 .. encoded_length / 2], &sig.r); + @memcpy(bytes[encoded_length / 2 ..], &sig.s); return bytes; } @@ -124,23 +126,23 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { /// Encode the signature using the DER format. /// The maximum length of the DER encoding is der_encoded_length_max. /// The function returns a slice, that can be shorter than der_encoded_length_max. - pub fn toDer(self: Signature, buf: *[der_encoded_length_max]u8) []u8 { + pub fn toDer(sig: Signature, buf: *[der_encoded_length_max]u8) []u8 { var fb = io.fixedBufferStream(buf); const w = fb.writer(); - const r_len = @as(u8, @intCast(self.r.len + (self.r[0] >> 7))); - const s_len = @as(u8, @intCast(self.s.len + (self.s[0] >> 7))); + const r_len = @as(u8, @intCast(sig.r.len + (sig.r[0] >> 7))); + const s_len = @as(u8, @intCast(sig.s.len + (sig.s[0] >> 7))); const seq_len = @as(u8, @intCast(2 + r_len + 2 + s_len)); w.writeAll(&[_]u8{ 0x30, seq_len }) catch unreachable; w.writeAll(&[_]u8{ 0x02, r_len }) catch unreachable; - if (self.r[0] >> 7 != 0) { + if (sig.r[0] >> 7 != 0) { w.writeByte(0x00) catch unreachable; } - w.writeAll(&self.r) catch unreachable; + w.writeAll(&sig.r) catch unreachable; w.writeAll(&[_]u8{ 0x02, s_len }) catch unreachable; - if (self.s[0] >> 7 != 0) { + if (sig.s[0] >> 7 != 0) { w.writeByte(0x00) catch unreachable; } - w.writeAll(&self.s) catch unreachable; + w.writeAll(&sig.s) catch unreachable; return fb.getWritten(); } @@ -236,7 +238,9 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { s: Curve.scalar.Scalar, public_key: PublicKey, - fn init(sig: Signature, public_key: PublicKey) (IdentityElementError || NonCanonicalError)!Verifier { + pub const InitError = IdentityElementError || NonCanonicalError; + + fn init(sig: Signature, public_key: PublicKey) InitError!Verifier { const r = try Curve.scalar.Scalar.fromBytes(sig.r, .big); const s = try Curve.scalar.Scalar.fromBytes(sig.s, .big); if (r.isZero() or s.isZero()) return error.IdentityElement; @@ -254,8 +258,11 @@ pub fn Ecdsa(comptime Curve: type, comptime Hash: type) type { self.h.update(data); } + pub const VerifyError = IdentityElementError || NonCanonicalError || + SignatureVerificationError; + /// Verify that the signature is valid for the entire message. - pub fn verify(self: *Verifier) (IdentityElementError || NonCanonicalError || SignatureVerificationError)!void { + pub fn verify(self: *Verifier) VerifyError!void { const ht = Curve.scalar.encoded_length; const h_len = @max(Hash.digest_length, ht); var h: [h_len]u8 = [_]u8{0} ** h_len; diff --git a/lib/std/crypto/tls.zig b/lib/std/crypto/tls.zig index fb1b550e42..74113225cb 100644 --- a/lib/std/crypto/tls.zig +++ b/lib/std/crypto/tls.zig @@ -54,6 +54,8 @@ pub const close_notify_alert = [_]u8{ }; pub const ProtocolVersion = enum(u16) { + tls_1_0 = 0x0301, + tls_1_1 = 0x0302, tls_1_2 = 0x0303, tls_1_3 = 0x0304, _, @@ -69,14 +71,18 @@ pub const ContentType = enum(u8) { }; pub const HandshakeType = enum(u8) { + hello_request = 0, client_hello = 1, server_hello = 2, new_session_ticket = 4, end_of_early_data = 5, encrypted_extensions = 8, certificate = 11, + server_key_exchange = 12, certificate_request = 13, + server_hello_done = 14, certificate_verify = 15, + client_key_exchange = 16, finished = 20, key_update = 24, message_hash = 254, @@ -198,36 +204,36 @@ pub const AlertDescription = enum(u8) { _, pub fn toError(alert: AlertDescription) Error!void { - return switch (alert) { + switch (alert) { .close_notify => {}, // not an error - .unexpected_message => error.TlsAlertUnexpectedMessage, - .bad_record_mac => error.TlsAlertBadRecordMac, - .record_overflow => error.TlsAlertRecordOverflow, - .handshake_failure => error.TlsAlertHandshakeFailure, - .bad_certificate => error.TlsAlertBadCertificate, - .unsupported_certificate => error.TlsAlertUnsupportedCertificate, - .certificate_revoked => error.TlsAlertCertificateRevoked, - .certificate_expired => error.TlsAlertCertificateExpired, - .certificate_unknown => error.TlsAlertCertificateUnknown, - .illegal_parameter => error.TlsAlertIllegalParameter, - .unknown_ca => error.TlsAlertUnknownCa, - .access_denied => error.TlsAlertAccessDenied, - .decode_error => error.TlsAlertDecodeError, - .decrypt_error => error.TlsAlertDecryptError, - .protocol_version => error.TlsAlertProtocolVersion, - .insufficient_security => error.TlsAlertInsufficientSecurity, - .internal_error => error.TlsAlertInternalError, - .inappropriate_fallback => error.TlsAlertInappropriateFallback, + .unexpected_message => return error.TlsAlertUnexpectedMessage, + .bad_record_mac => return error.TlsAlertBadRecordMac, + .record_overflow => return error.TlsAlertRecordOverflow, + .handshake_failure => return error.TlsAlertHandshakeFailure, + .bad_certificate => return error.TlsAlertBadCertificate, + .unsupported_certificate => return error.TlsAlertUnsupportedCertificate, + .certificate_revoked => return error.TlsAlertCertificateRevoked, + .certificate_expired => return error.TlsAlertCertificateExpired, + .certificate_unknown => return error.TlsAlertCertificateUnknown, + .illegal_parameter => return error.TlsAlertIllegalParameter, + .unknown_ca => return error.TlsAlertUnknownCa, + .access_denied => return error.TlsAlertAccessDenied, + .decode_error => return error.TlsAlertDecodeError, + .decrypt_error => return error.TlsAlertDecryptError, + .protocol_version => return error.TlsAlertProtocolVersion, + .insufficient_security => return error.TlsAlertInsufficientSecurity, + .internal_error => return error.TlsAlertInternalError, + .inappropriate_fallback => return error.TlsAlertInappropriateFallback, .user_canceled => {}, // not an error - .missing_extension => error.TlsAlertMissingExtension, - .unsupported_extension => error.TlsAlertUnsupportedExtension, - .unrecognized_name => error.TlsAlertUnrecognizedName, - .bad_certificate_status_response => error.TlsAlertBadCertificateStatusResponse, - .unknown_psk_identity => error.TlsAlertUnknownPskIdentity, - .certificate_required => error.TlsAlertCertificateRequired, - .no_application_protocol => error.TlsAlertNoApplicationProtocol, - _ => error.TlsAlertUnknown, - }; + .missing_extension => return error.TlsAlertMissingExtension, + .unsupported_extension => return error.TlsAlertUnsupportedExtension, + .unrecognized_name => return error.TlsAlertUnrecognizedName, + .bad_certificate_status_response => return error.TlsAlertBadCertificateStatusResponse, + .unknown_psk_identity => return error.TlsAlertUnknownPskIdentity, + .certificate_required => return error.TlsAlertCertificateRequired, + .no_application_protocol => return error.TlsAlertNoApplicationProtocol, + _ => return error.TlsAlertUnknown, + } } }; @@ -260,6 +266,17 @@ pub const SignatureScheme = enum(u16) { rsa_pkcs1_sha1 = 0x0201, ecdsa_sha1 = 0x0203, + ecdsa_brainpoolP256r1tls13_sha256 = 0x081a, + ecdsa_brainpoolP384r1tls13_sha384 = 0x081b, + ecdsa_brainpoolP512r1tls13_sha512 = 0x081c, + + rsa_sha224 = 0x0301, + dsa_sha224 = 0x0302, + ecdsa_sha224 = 0x0303, + dsa_sha256 = 0x0402, + dsa_sha384 = 0x0502, + dsa_sha512 = 0x0602, + _, }; @@ -285,7 +302,27 @@ pub const NamedGroup = enum(u16) { _, }; +pub const PskKeyExchangeMode = enum(u8) { + psk_ke = 0, + psk_dhe_ke = 1, + _, +}; + pub const CipherSuite = enum(u16) { + RSA_WITH_AES_128_CBC_SHA = 0x002F, + DHE_RSA_WITH_AES_128_CBC_SHA = 0x0033, + RSA_WITH_AES_256_CBC_SHA = 0x0035, + DHE_RSA_WITH_AES_256_CBC_SHA = 0x0039, + RSA_WITH_AES_128_CBC_SHA256 = 0x003C, + RSA_WITH_AES_256_CBC_SHA256 = 0x003D, + DHE_RSA_WITH_AES_128_CBC_SHA256 = 0x0067, + DHE_RSA_WITH_AES_256_CBC_SHA256 = 0x006B, + RSA_WITH_AES_128_GCM_SHA256 = 0x009C, + RSA_WITH_AES_256_GCM_SHA384 = 0x009D, + DHE_RSA_WITH_AES_128_GCM_SHA256 = 0x009E, + DHE_RSA_WITH_AES_256_GCM_SHA384 = 0x009F, + EMPTY_RENEGOTIATION_INFO_SCSV = 0x00FF, + AES_128_GCM_SHA256 = 0x1301, AES_256_GCM_SHA384 = 0x1302, CHACHA20_POLY1305_SHA256 = 0x1303, @@ -293,6 +330,102 @@ pub const CipherSuite = enum(u16) { AES_128_CCM_8_SHA256 = 0x1305, AEGIS_256_SHA512 = 0x1306, AEGIS_128L_SHA256 = 0x1307, + + ECDHE_ECDSA_WITH_AES_128_CBC_SHA = 0xC009, + ECDHE_ECDSA_WITH_AES_256_CBC_SHA = 0xC00A, + ECDHE_RSA_WITH_AES_128_CBC_SHA = 0xC013, + ECDHE_RSA_WITH_AES_256_CBC_SHA = 0xC014, + ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = 0xC023, + ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = 0xC024, + ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027, + ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 0xC028, + ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = 0xC02B, + ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = 0xC02C, + ECDHE_RSA_WITH_AES_128_GCM_SHA256 = 0xC02F, + ECDHE_RSA_WITH_AES_256_GCM_SHA384 = 0xC030, + + ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA8, + ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCA9, + DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = 0xCCAA, + + _, + + pub const With = enum { + AES_128_CBC_SHA, + AES_256_CBC_SHA, + AES_128_CBC_SHA256, + AES_256_CBC_SHA256, + AES_256_CBC_SHA384, + + AES_128_GCM_SHA256, + AES_256_GCM_SHA384, + + CHACHA20_POLY1305_SHA256, + + AES_128_CCM_SHA256, + AES_128_CCM_8_SHA256, + + AEGIS_256_SHA512, + AEGIS_128L_SHA256, + }; + + pub fn with(cipher_suite: CipherSuite) With { + return switch (cipher_suite) { + .RSA_WITH_AES_128_CBC_SHA, + .DHE_RSA_WITH_AES_128_CBC_SHA, + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + .ECDHE_RSA_WITH_AES_128_CBC_SHA, + => .AES_128_CBC_SHA, + .RSA_WITH_AES_256_CBC_SHA, + .DHE_RSA_WITH_AES_256_CBC_SHA, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + .ECDHE_RSA_WITH_AES_256_CBC_SHA, + => .AES_256_CBC_SHA, + .RSA_WITH_AES_128_CBC_SHA256, + .DHE_RSA_WITH_AES_128_CBC_SHA256, + .ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + .ECDHE_RSA_WITH_AES_128_CBC_SHA256, + => .AES_128_CBC_SHA256, + .RSA_WITH_AES_256_CBC_SHA256, + .DHE_RSA_WITH_AES_256_CBC_SHA256, + => .AES_256_CBC_SHA256, + .ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + .ECDHE_RSA_WITH_AES_256_CBC_SHA384, + => .AES_256_CBC_SHA384, + + .RSA_WITH_AES_128_GCM_SHA256, + .DHE_RSA_WITH_AES_128_GCM_SHA256, + .AES_128_GCM_SHA256, + .ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + => .AES_128_GCM_SHA256, + .RSA_WITH_AES_256_GCM_SHA384, + .DHE_RSA_WITH_AES_256_GCM_SHA384, + .AES_256_GCM_SHA384, + .ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + => .AES_256_GCM_SHA384, + + .CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + .ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + .DHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => .CHACHA20_POLY1305_SHA256, + + .AES_128_CCM_SHA256 => .AES_128_CCM_SHA256, + .AES_128_CCM_8_SHA256 => .AES_128_CCM_8_SHA256, + + .AEGIS_256_SHA512 => .AEGIS_256_SHA512, + .AEGIS_128L_SHA256 => .AEGIS_128L_SHA256, + + .EMPTY_RENEGOTIATION_INFO_SCSV => unreachable, + _ => unreachable, + }; + } +}; + +pub const CompressionMethod = enum(u8) { + null = 0, _, }; @@ -308,58 +441,114 @@ pub const KeyUpdateRequest = enum(u8) { _, }; -pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type) type { - return struct { - pub const AEAD = AeadType; - pub const Hash = HashType; - pub const Hmac = crypto.auth.hmac.Hmac(Hash); - pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); +pub const ChangeCipherSpecType = enum(u8) { + change_cipher_spec = 1, + _, +}; - handshake_secret: [Hkdf.prk_length]u8, - master_secret: [Hkdf.prk_length]u8, - client_handshake_key: [AEAD.key_length]u8, - server_handshake_key: [AEAD.key_length]u8, - client_finished_key: [Hmac.key_length]u8, - server_finished_key: [Hmac.key_length]u8, - client_handshake_iv: [AEAD.nonce_length]u8, - server_handshake_iv: [AEAD.nonce_length]u8, - transcript_hash: Hash, +pub fn HandshakeCipherT(comptime AeadType: type, comptime HashType: type, comptime explicit_iv_length: comptime_int) type { + return struct { + pub const A = ApplicationCipherT(AeadType, HashType, explicit_iv_length); + + transcript_hash: A.Hash, + version: union { + tls_1_2: struct { + expected_server_verify_data: [A.verify_data_length]u8, + app_cipher: A.Tls_1_2, + }, + tls_1_3: struct { + handshake_secret: [A.Hkdf.prk_length]u8, + master_secret: [A.Hkdf.prk_length]u8, + client_handshake_key: [A.AEAD.key_length]u8, + server_handshake_key: [A.AEAD.key_length]u8, + client_finished_key: [A.Hmac.key_length]u8, + server_finished_key: [A.Hmac.key_length]u8, + client_handshake_iv: [A.AEAD.nonce_length]u8, + server_handshake_iv: [A.AEAD.nonce_length]u8, + }, + }, }; } pub const HandshakeCipher = union(enum) { - AES_128_GCM_SHA256: HandshakeCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), - AES_256_GCM_SHA384: HandshakeCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), - CHACHA20_POLY1305_SHA256: HandshakeCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), - AEGIS_256_SHA512: HandshakeCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512), - AEGIS_128L_SHA256: HandshakeCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), + AES_128_GCM_SHA256: HandshakeCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256, 8), + AES_256_GCM_SHA384: HandshakeCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384, 8), + CHACHA20_POLY1305_SHA256: HandshakeCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256, 0), + AEGIS_256_SHA512: HandshakeCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512, 0), + AEGIS_128L_SHA256: HandshakeCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256, 0), }; -pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type) type { - return struct { +pub fn ApplicationCipherT(comptime AeadType: type, comptime HashType: type, comptime explicit_iv_length: comptime_int) type { + return union { pub const AEAD = AeadType; pub const Hash = HashType; pub const Hmac = crypto.auth.hmac.Hmac(Hash); pub const Hkdf = crypto.kdf.hkdf.Hkdf(Hmac); - client_secret: [Hash.digest_length]u8, - server_secret: [Hash.digest_length]u8, - client_key: [AEAD.key_length]u8, - server_key: [AEAD.key_length]u8, - client_iv: [AEAD.nonce_length]u8, - server_iv: [AEAD.nonce_length]u8, + pub const enc_key_length = AEAD.key_length; + pub const fixed_iv_length = AEAD.nonce_length - explicit_iv_length; + pub const record_iv_length = explicit_iv_length; + pub const mac_length = AEAD.tag_length; + pub const mac_key_length = Hmac.key_length_min; + pub const verify_data_length = 12; + + tls_1_2: Tls_1_2, + tls_1_3: Tls_1_3, + + pub const Tls_1_2 = extern struct { + client_write_MAC_key: [mac_key_length]u8, + server_write_MAC_key: [mac_key_length]u8, + client_write_key: [enc_key_length]u8, + server_write_key: [enc_key_length]u8, + client_write_IV: [fixed_iv_length]u8, + server_write_IV: [fixed_iv_length]u8, + // non-standard entropy + client_salt: [record_iv_length]u8, + }; + + pub const Tls_1_3 = struct { + client_secret: [Hash.digest_length]u8, + server_secret: [Hash.digest_length]u8, + client_key: [AEAD.key_length]u8, + server_key: [AEAD.key_length]u8, + client_iv: [AEAD.nonce_length]u8, + server_iv: [AEAD.nonce_length]u8, + }; }; } /// Encryption parameters for application traffic. pub const ApplicationCipher = union(enum) { - AES_128_GCM_SHA256: ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256), - AES_256_GCM_SHA384: ApplicationCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384), - CHACHA20_POLY1305_SHA256: ApplicationCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256), - AEGIS_256_SHA512: ApplicationCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512), - AEGIS_128L_SHA256: ApplicationCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256), + AES_128_GCM_SHA256: ApplicationCipherT(crypto.aead.aes_gcm.Aes128Gcm, crypto.hash.sha2.Sha256, 8), + AES_256_GCM_SHA384: ApplicationCipherT(crypto.aead.aes_gcm.Aes256Gcm, crypto.hash.sha2.Sha384, 8), + CHACHA20_POLY1305_SHA256: ApplicationCipherT(crypto.aead.chacha_poly.ChaCha20Poly1305, crypto.hash.sha2.Sha256, 0), + AEGIS_256_SHA512: ApplicationCipherT(crypto.aead.aegis.Aegis256, crypto.hash.sha2.Sha512, 0), + AEGIS_128L_SHA256: ApplicationCipherT(crypto.aead.aegis.Aegis128L, crypto.hash.sha2.Sha256, 0), }; +pub fn hmacExpandLabel( + comptime Hmac: type, + secret: []const u8, + label_then_seed: []const []const u8, + comptime len: usize, +) [len]u8 { + const initial_hmac: Hmac = .init(secret); + var a: [Hmac.mac_length]u8 = undefined; + var result: [std.mem.alignForwardAnyAlign(usize, len, Hmac.mac_length)]u8 = undefined; + var index: usize = 0; + while (index < result.len) : (index += Hmac.mac_length) { + var a_hmac = initial_hmac; + if (index > 0) a_hmac.update(&a) else for (label_then_seed) |part| a_hmac.update(part); + a_hmac.final(&a); + + var result_hmac = initial_hmac; + result_hmac.update(&a); + for (label_then_seed) |part| result_hmac.update(part); + result_hmac.final(result[index..][0..Hmac.mac_length]); + } + return result[0..len].*; +} + pub fn hkdfExpandLabel( comptime Hkdf: type, key: [Hkdf.prk_length]u8, @@ -399,38 +588,39 @@ pub fn hmac(comptime Hmac: type, message: []const u8, key: [Hmac.key_length]u8) return result; } -pub inline fn extension(comptime et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 { - return int2(@intFromEnum(et)) ++ array(1, bytes); +pub inline fn extension(et: ExtensionType, bytes: anytype) [2 + 2 + bytes.len]u8 { + return int(u16, @intFromEnum(et)) ++ array(u16, u8, bytes); } -pub inline fn array(comptime elem_size: comptime_int, bytes: anytype) [2 + bytes.len]u8 { - comptime assert(bytes.len % elem_size == 0); - return int2(bytes.len) ++ bytes; -} - -pub inline fn enum_array(comptime E: type, comptime tags: []const E) [2 + @sizeOf(E) * tags.len]u8 { - assert(@sizeOf(E) == 2); - var result: [tags.len * 2]u8 = undefined; - for (tags, 0..) |elem, i| { - result[i * 2] = @as(u8, @truncate(@intFromEnum(elem) >> 8)); - result[i * 2 + 1] = @as(u8, @truncate(@intFromEnum(elem))); +pub inline fn array( + comptime Len: type, + comptime Elem: type, + elems: anytype, +) [@divExact(@bitSizeOf(Len), 8) + @divExact(@bitSizeOf(Elem), 8) * elems.len]u8 { + const len_size = @divExact(@bitSizeOf(Len), 8); + const elem_size = @divExact(@bitSizeOf(Elem), 8); + var arr: [len_size + elem_size * elems.len]u8 = undefined; + std.mem.writeInt(Len, arr[0..len_size], @intCast(elem_size * elems.len), .big); + const ElemInt = @Type(.{ .int = .{ .signedness = .unsigned, .bits = @bitSizeOf(Elem) } }); + for (0.., @as([elems.len]Elem, elems)) |index, elem| { + std.mem.writeInt( + ElemInt, + arr[len_size + elem_size * index ..][0..elem_size], + switch (@typeInfo(Elem)) { + .int => @as(Elem, elem), + .@"enum" => @intFromEnum(@as(Elem, elem)), + else => @bitCast(@as(Elem, elem)), + }, + .big, + ); } - return array(2, result); + return arr; } -pub inline fn int2(x: u16) [2]u8 { - return .{ - @as(u8, @truncate(x >> 8)), - @as(u8, @truncate(x)), - }; -} - -pub inline fn int3(x: u24) [3]u8 { - return .{ - @as(u8, @truncate(x >> 16)), - @as(u8, @truncate(x >> 8)), - @as(u8, @truncate(x)), - }; +pub inline fn int(comptime Int: type, val: Int) [@divExact(@bitSizeOf(Int), 8)]u8 { + var arr: [@divExact(@bitSizeOf(Int), 8)]u8 = undefined; + std.mem.writeInt(Int, &arr, val, .big); + return arr; } /// An abstraction to ensure that protocol-parsing code does not perform an @@ -512,9 +702,8 @@ pub const Decoder = struct { else => @compileError("unsupported int type: " ++ @typeName(T)), }, .@"enum" => |info| { - const int = d.decode(info.tag_type); if (info.is_exhaustive) @compileError("exhaustive enum cannot be used"); - return @as(T, @enumFromInt(int)); + return @enumFromInt(d.decode(info.tag_type)); }, else => @compileError("unsupported type: " ++ @typeName(T)), } diff --git a/lib/std/crypto/tls/Client.zig b/lib/std/crypto/tls/Client.zig index 84dbb2167a..20ea485049 100644 --- a/lib/std/crypto/tls/Client.zig +++ b/lib/std/crypto/tls/Client.zig @@ -8,12 +8,12 @@ const assert = std.debug.assert; const Certificate = std.crypto.Certificate; const max_ciphertext_len = tls.max_ciphertext_len; +const hmacExpandLabel = tls.hmacExpandLabel; const hkdfExpandLabel = tls.hkdfExpandLabel; -const int2 = tls.int2; -const int3 = tls.int3; +const int = tls.int; const array = tls.array; -const enum_array = tls.enum_array; +tls_version: tls.ProtocolVersion, read_seq: u64, write_seq: u64, /// The starting index of cleartext bytes inside `partially_read_buffer`. @@ -33,7 +33,7 @@ received_close_notify: bool, /// This makes the application vulnerable to truncation attacks unless the /// application layer itself verifies that the amount of data received equals /// the amount of data expected, such as HTTP with the Content-Length header. -allow_truncation_attacks: bool = false, +allow_truncation_attacks: bool, application_cipher: tls.ApplicationCipher, /// The size is enough to contain exactly one TLSCiphertext record. /// This buffer is segmented into four parts: @@ -44,6 +44,24 @@ application_cipher: tls.ApplicationCipher, /// The fields `partial_cleartext_idx`, `partial_ciphertext_idx`, and /// `partial_ciphertext_end` describe the span of the segments. partially_read_buffer: [tls.max_ciphertext_record_len]u8, +/// If non-null, ssl secrets are logged to a file. Creating such a log file allows other +/// programs with access to that file to decrypt all traffic over this connection. +ssl_key_log: ?struct { + client_key_seq: u64, + server_key_seq: u64, + client_random: [32]u8, + file: std.fs.File, + + fn clientCounter(key_log: *@This()) u64 { + defer key_log.client_key_seq += 1; + return key_log.client_key_seq; + } + + fn serverCounter(key_log: *@This()) u64 { + defer key_log.server_key_seq += 1; + return key_log.server_key_seq; + } +}, /// This is an example of the type that is needed by the read and write /// functions. It can have any fields but it must at least have these @@ -88,6 +106,32 @@ pub const StreamInterface = struct { } }; +pub const Options = struct { + /// How to perform host verification of server certificates. + host: union(enum) { + /// No host verification is performed, which prevents a trusted connection from + /// being established. + no_verification, + /// Verify that the server certificate was issues for a given host. + explicit: []const u8, + }, + /// How to verify the authenticity of server certificates. + ca: union(enum) { + /// No ca verification is performed, which prevents a trusted connection from + /// being established. + no_verification, + /// Verify that the server certificate is a valid self-signed certificate. + /// This provides no authorization guarantees, as anyone can create a + /// self-signed certificate. + self_signed, + /// Verify that the server certificate is authorized by a given ca bundle. + bundle: Certificate.Bundle, + }, + /// If non-null, ssl secrets are logged to this file. Creating such a log file allows + /// other programs with access to that file to decrypt all traffic over this connection. + ssl_key_log_file: ?std.fs.File = null, +}; + pub fn InitError(comptime Stream: type) type { return std.mem.Allocator.Error || Stream.WriteError || Stream.ReadError || tls.AlertDescription.Error || error{ InsufficientEntropy, @@ -136,326 +180,186 @@ pub fn InitError(comptime Stream: type) type { }; } -/// Initiates a TLS handshake and establishes a TLSv1.3 session with `stream`, which +/// Initiates a TLS handshake and establishes a TLSv1.2 or TLSv1.3 session with `stream`, which /// must conform to `StreamInterface`. /// /// `host` is only borrowed during this function call. -pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) InitError(@TypeOf(stream))!Client { +pub fn init(stream: anytype, options: Options) InitError(@TypeOf(stream))!Client { + const host = switch (options.host) { + .no_verification => "", + .explicit => |host| host, + }; const host_len: u16 = @intCast(host.len); - var random_buffer: [128]u8 = undefined; + var random_buffer: [176]u8 = undefined; crypto.random.bytes(&random_buffer); - const hello_rand = random_buffer[0..32].*; + const client_hello_rand = random_buffer[0..32].*; + var key_seq: u64 = 0; + var server_hello_rand: [32]u8 = undefined; const legacy_session_id = random_buffer[32..64].*; - const x25519_kp_seed = random_buffer[64..96].*; - const secp256r1_kp_seed = random_buffer[96..128].*; - const x25519_kp = crypto.dh.X25519.KeyPair.create(x25519_kp_seed) catch |err| switch (err) { - // Only possible to happen if the private key is all zeroes. + var key_share = KeyShare.init(random_buffer[64..176].*) catch |err| switch (err) { + // Only possible to happen if the seed is all zeroes. error.IdentityElement => return error.InsufficientEntropy, }; - const secp256r1_kp = crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair.create(secp256r1_kp_seed) catch |err| switch (err) { - // Only possible to happen if the private key is all zeroes. - error.IdentityElement => return error.InsufficientEntropy, - }; - const ml_kem768_kp = crypto.kem.ml_kem.MLKem768.KeyPair.create(null) catch {}; - const extensions_payload = - tls.extension(.supported_versions, [_]u8{ - 0x02, // byte length of supported versions - 0x03, 0x04, // TLS 1.3 - }) ++ tls.extension(.signature_algorithms, enum_array(tls.SignatureScheme, &.{ + const extensions_payload = tls.extension(.supported_versions, array(u8, tls.ProtocolVersion, .{ + .tls_1_3, + .tls_1_2, + })) ++ tls.extension(.signature_algorithms, array(u16, tls.SignatureScheme, .{ .ecdsa_secp256r1_sha256, .ecdsa_secp384r1_sha384, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, .rsa_pss_rsae_sha256, .rsa_pss_rsae_sha384, .rsa_pss_rsae_sha512, + .rsa_pss_pss_sha256, + .rsa_pss_pss_sha384, + .rsa_pss_pss_sha512, + .rsa_pkcs1_sha1, .ed25519, - })) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{ + })) ++ tls.extension(.supported_groups, array(u16, tls.NamedGroup, .{ .x25519_ml_kem768, .secp256r1, + .secp384r1, .x25519, - })) ++ tls.extension( - .key_share, - array(1, int2(@intFromEnum(tls.NamedGroup.x25519)) ++ - array(1, x25519_kp.public_key) ++ - int2(@intFromEnum(tls.NamedGroup.secp256r1)) ++ - array(1, secp256r1_kp.public_key.toUncompressedSec1()) ++ - int2(@intFromEnum(tls.NamedGroup.x25519_ml_kem768)) ++ - array(1, x25519_kp.public_key ++ ml_kem768_kp.public_key.toBytes())), - ) ++ - int2(@intFromEnum(tls.ExtensionType.server_name)) ++ - int2(host_len + 5) ++ // byte length of this extension payload - int2(host_len + 3) ++ // server_name_list byte count - [1]u8{0x00} ++ // name_type - int2(host_len); + })) ++ tls.extension(.psk_key_exchange_modes, array(u8, tls.PskKeyExchangeMode, .{ + .psk_dhe_ke, + })) ++ tls.extension(.key_share, array( + u16, + u8, + int(u16, @intFromEnum(tls.NamedGroup.x25519_ml_kem768)) ++ + array(u16, u8, key_share.ml_kem768_kp.public_key.toBytes() ++ key_share.x25519_kp.public_key) ++ + int(u16, @intFromEnum(tls.NamedGroup.secp256r1)) ++ + array(u16, u8, key_share.secp256r1_kp.public_key.toUncompressedSec1()) ++ + int(u16, @intFromEnum(tls.NamedGroup.secp384r1)) ++ + array(u16, u8, key_share.secp384r1_kp.public_key.toUncompressedSec1()) ++ + int(u16, @intFromEnum(tls.NamedGroup.x25519)) ++ + array(u16, u8, key_share.x25519_kp.public_key), + )); + const server_name_extension = int(u16, @intFromEnum(tls.ExtensionType.server_name)) ++ + int(u16, 2 + 1 + 2 + host_len) ++ // byte length of this extension payload + int(u16, 1 + 2 + host_len) ++ // server_name_list byte count + .{0x00} ++ // name_type + int(u16, host_len); + const server_name_extension_len = switch (options.host) { + .no_verification => 0, + .explicit => server_name_extension.len + host_len, + }; const extensions_header = - int2(@intCast(extensions_payload.len + host_len)) ++ - extensions_payload; - - const legacy_compression_methods = 0x0100; + int(u16, @intCast(extensions_payload.len + server_name_extension_len)) ++ + extensions_payload ++ + server_name_extension; const client_hello = - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - hello_rand ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + client_hello_rand ++ [1]u8{32} ++ legacy_session_id ++ cipher_suites ++ - int2(legacy_compression_methods) ++ + array(u8, tls.CompressionMethod, .{.null}) ++ extensions_header; - const out_handshake = - [_]u8{@intFromEnum(tls.HandshakeType.client_hello)} ++ - int3(@intCast(client_hello.len + host_len)) ++ + const out_handshake = .{@intFromEnum(tls.HandshakeType.client_hello)} ++ + int(u24, @intCast(client_hello.len - server_name_extension.len + server_name_extension_len)) ++ client_hello; - const plaintext_header = [_]u8{ - @intFromEnum(tls.ContentType.handshake), - 0x03, 0x01, // legacy_record_version - } ++ int2(@intCast(out_handshake.len + host_len)) ++ out_handshake; + const cleartext_header_buf = .{@intFromEnum(tls.ContentType.handshake)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_0)) ++ + int(u16, @intCast(out_handshake.len - server_name_extension.len + server_name_extension_len)) ++ + out_handshake; + const cleartext_header = switch (options.host) { + .no_verification => cleartext_header_buf[0 .. cleartext_header_buf.len - server_name_extension.len], + .explicit => &cleartext_header_buf, + }; { var iovecs = [_]std.posix.iovec_const{ - .{ - .base = &plaintext_header, - .len = plaintext_header.len, - }, - .{ - .base = host.ptr, - .len = host.len, - }, + .{ .base = cleartext_header.ptr, .len = cleartext_header.len }, + .{ .base = host.ptr, .len = host.len }, }; - try stream.writevAll(&iovecs); + try stream.writevAll(iovecs[0..if (host.len == 0) 1 else 2]); } - const client_hello_bytes1 = plaintext_header[5..]; - - var handshake_cipher: tls.HandshakeCipher = undefined; - var handshake_buffer: [8000]u8 = undefined; - var d: tls.Decoder = .{ .buf = &handshake_buffer }; - { - try d.readAtLeastOurAmt(stream, tls.record_header_len); - const ct = d.decode(tls.ContentType); - d.skip(2); // legacy_record_version - const record_len = d.decode(u16); - try d.readAtLeast(stream, record_len); - const server_hello_fragment = d.buf[d.idx..][0..record_len]; - var ptd = try d.sub(record_len); - switch (ct) { - .alert => { - try ptd.ensure(2); - const level = ptd.decode(tls.AlertLevel); - const desc = ptd.decode(tls.AlertDescription); - _ = level; - - // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - try ptd.ensure(4); - const handshake_type = ptd.decode(tls.HandshakeType); - if (handshake_type != .server_hello) return error.TlsUnexpectedMessage; - const length = ptd.decode(u24); - var hsd = try ptd.sub(length); - try hsd.ensure(2 + 32 + 1 + 32 + 2 + 1 + 2); - const legacy_version = hsd.decode(u16); - const random = hsd.array(32); - if (mem.eql(u8, random, &tls.hello_retry_request_sequence)) { - // This is a HelloRetryRequest message. This client implementation - // does not expect to get one. - return error.TlsUnexpectedMessage; - } - const legacy_session_id_echo_len = hsd.decode(u8); - if (legacy_session_id_echo_len != 32) return error.TlsIllegalParameter; - const legacy_session_id_echo = hsd.array(32); - if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) - return error.TlsIllegalParameter; - const cipher_suite_tag = hsd.decode(tls.CipherSuite); - hsd.skip(1); // legacy_compression_method - const extensions_size = hsd.decode(u16); - var all_extd = try hsd.sub(extensions_size); - var supported_version: u16 = 0; - var shared_key: []const u8 = undefined; - var have_shared_key = false; - while (!all_extd.eof()) { - try all_extd.ensure(2 + 2); - const et = all_extd.decode(tls.ExtensionType); - const ext_size = all_extd.decode(u16); - var extd = try all_extd.sub(ext_size); - switch (et) { - .supported_versions => { - if (supported_version != 0) return error.TlsIllegalParameter; - try extd.ensure(2); - supported_version = extd.decode(u16); - }, - .key_share => { - if (have_shared_key) return error.TlsIllegalParameter; - have_shared_key = true; - try extd.ensure(4); - const named_group = extd.decode(tls.NamedGroup); - const key_size = extd.decode(u16); - try extd.ensure(key_size); - switch (named_group) { - .x25519_ml_kem768 => { - const xksl = crypto.dh.X25519.public_length; - const hksl = xksl + crypto.kem.ml_kem.MLKem768.ciphertext_length; - if (key_size != hksl) - return error.TlsIllegalParameter; - const server_ks = extd.array(hksl); - - shared_key = &((crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - server_ks[0..xksl].*, - ) catch return error.TlsDecryptFailure) ++ (ml_kem768_kp.secret_key.decaps( - server_ks[xksl..hksl], - ) catch return error.TlsDecryptFailure)); - }, - .x25519 => { - const ksl = crypto.dh.X25519.public_length; - if (key_size != ksl) return error.TlsIllegalParameter; - const server_pub_key = extd.array(ksl); - - shared_key = &(crypto.dh.X25519.scalarmult( - x25519_kp.secret_key, - server_pub_key.*, - ) catch return error.TlsDecryptFailure); - }, - .secp256r1 => { - const server_pub_key = extd.slice(key_size); - - const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; - const pk = PublicKey.fromSec1(server_pub_key) catch { - return error.TlsDecryptFailure; - }; - const mul = pk.p.mulPublic(secp256r1_kp.secret_key.bytes, .big) catch { - return error.TlsDecryptFailure; - }; - shared_key = &mul.affineCoordinates().x.toBytes(.big); - }, - else => { - return error.TlsIllegalParameter; - }, - } - }, - else => {}, - } - } - if (!have_shared_key) return error.TlsIllegalParameter; - - const tls_version = if (supported_version == 0) legacy_version else supported_version; - if (tls_version != @intFromEnum(tls.ProtocolVersion.tls_1_3)) - return error.TlsIllegalParameter; - - switch (cipher_suite_tag) { - inline .AES_128_GCM_SHA256, - .AES_256_GCM_SHA384, - .CHACHA20_POLY1305_SHA256, - .AEGIS_256_SHA512, - .AEGIS_128L_SHA256, - => |tag| { - const P = std.meta.TagPayloadByName(tls.HandshakeCipher, @tagName(tag)); - handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag), .{ - .handshake_secret = undefined, - .master_secret = undefined, - .client_handshake_key = undefined, - .server_handshake_key = undefined, - .client_finished_key = undefined, - .server_finished_key = undefined, - .client_handshake_iv = undefined, - .server_handshake_iv = undefined, - .transcript_hash = P.Hash.init(.{}), - }); - const p = &@field(handshake_cipher, @tagName(tag)); - p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1 - p.transcript_hash.update(host); // Client Hello part 2 - p.transcript_hash.update(server_hello_fragment); - const hello_hash = p.transcript_hash.peek(); - const zeroes = [1]u8{0} ** P.Hash.digest_length; - const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); - const empty_hash = tls.emptyHash(P.Hash); - const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); - p.handshake_secret = P.Hkdf.extract(&hs_derived_secret, shared_key); - const ap_derived_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); - p.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); - const client_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); - p.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); - p.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); - p.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - p.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); - }, - else => { - return error.TlsIllegalParameter; - }, - } - }, - else => return error.TlsUnexpectedMessage, - } - } - - // This is used for two purposes: + var tls_version: tls.ProtocolVersion = undefined; + // These are used for two purposes: // * Detect whether a certificate is the first one presented, in which case // we need to verify the host name. + var cert_index: usize = 0; // * Flip back and forth between the two cleartext buffers in order to keep // the previous certificate in memory so that it can be verified by the // next one. - var cert_index: usize = 0; + var cert_buf_index: usize = 0; + var write_seq: u64 = 0; var read_seq: u64 = 0; var prev_cert: Certificate.Parsed = undefined; - // Set to true once a trust chain has been established from the first - // certificate to a root CA. + const CipherState = enum { + /// No cipher is in use + cleartext, + /// Handshake cipher is in use + handshake, + /// Application cipher is in use + application, + }; + var pending_cipher_state: CipherState = .cleartext; + var cipher_state = pending_cipher_state; const HandshakeState = enum { + /// In this state we expect only a server hello message. + hello, /// In this state we expect only an encrypted_extensions message. encrypted_extensions, - /// In this state we expect certificate messages. + /// In this state we expect certificate handshake messages. certificate, /// In this state we expect certificate or certificate_verify messages. /// certificate messages are ignored since the trust chain is already /// established. trust_chain_established, - /// In this state, we expect only the finished message. + /// In this state, we expect only the server_hello_done handshake message. + server_hello_done, + /// In this state, we expect only the finished handshake message. finished, }; - var handshake_state: HandshakeState = .encrypted_extensions; - var cleartext_bufs: [2][8000]u8 = undefined; - var main_cert_pub_key_algo: Certificate.AlgorithmCategory = undefined; - var main_cert_pub_key_buf: [600]u8 = undefined; - var main_cert_pub_key_len: u16 = undefined; + var handshake_state: HandshakeState = .hello; + var handshake_cipher: tls.HandshakeCipher = undefined; + var main_cert_pub_key: CertificatePublicKey = undefined; const now_sec = std.time.timestamp(); - while (true) { + var cleartext_fragment_start: usize = 0; + var cleartext_fragment_end: usize = 0; + var cleartext_bufs: [2][tls.max_ciphertext_inner_record_len]u8 = undefined; + var handshake_buffer: [tls.max_ciphertext_record_len]u8 = undefined; + var d: tls.Decoder = .{ .buf = &handshake_buffer }; + fragment: while (true) { try d.readAtLeastOurAmt(stream, tls.record_header_len); - const record_header = d.buf[d.idx..][0..5]; - const ct = d.decode(tls.ContentType); + const record_header = d.buf[d.idx..][0..tls.record_header_len]; + const record_ct = d.decode(tls.ContentType); d.skip(2); // legacy_version const record_len = d.decode(u16); try d.readAtLeast(stream, record_len); var record_decoder = try d.sub(record_len); - switch (ct) { - .change_cipher_spec => { - try record_decoder.ensure(1); - if (record_decoder.decode(u8) != 0x01) return error.TlsIllegalParameter; - }, - .application_data => { - const cleartext_buf = &cleartext_bufs[cert_index % 2]; - - const cleartext = switch (handshake_cipher) { - inline else => |*p| c: { - const P = @TypeOf(p.*); - const ciphertext_len = record_len - P.AEAD.tag_length; - try record_decoder.ensure(ciphertext_len + P.AEAD.tag_length); - const ciphertext = record_decoder.slice(ciphertext_len); - if (ciphertext.len > cleartext_buf.len) return error.TlsRecordOverflow; - const cleartext = cleartext_buf[0..ciphertext.len]; + var ctd, const ct = content: switch (cipher_state) { + .cleartext => .{ record_decoder, record_ct }, + .handshake => { + std.debug.assert(tls_version == .tls_1_3); + if (record_ct != .application_data) return error.TlsUnexpectedMessage; + try record_decoder.ensure(record_len); + const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; + switch (handshake_cipher) { + inline else => |*p| { + const pv = &p.version.tls_1_3; + const P = @TypeOf(p.*).A; + if (record_len < P.AEAD.tag_length) return error.TlsRecordOverflow; + const ciphertext = record_decoder.slice(record_len - P.AEAD.tag_length); + const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..]; + if (ciphertext.len > cleartext_fragment_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_fragment_buf[0..ciphertext.len]; const auth_tag = record_decoder.array(P.AEAD.tag_length).*; const nonce = if (builtin.zig_backend == .stage2_x86_64 and P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) nonce: { - var nonce = p.server_handshake_iv; + var nonce = pv.server_handshake_iv; const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ read_seq, .big); break :nonce nonce; @@ -463,268 +367,559 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In const V = @Vector(P.AEAD.nonce_length, u8); const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); const operand: V = pad ++ @as([8]u8, @bitCast(big(read_seq))); - break :nonce @as(V, p.server_handshake_iv) ^ operand; + break :nonce @as(V, pv.server_handshake_iv) ^ operand; }; - read_seq += 1; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, p.server_handshake_key) catch + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, record_header, nonce, pv.server_handshake_key) catch return error.TlsBadRecordMac; - break :c @constCast(mem.trimRight(u8, cleartext, "\x00")); + cleartext_fragment_end += std.mem.trimRight(u8, cleartext, "\x00").len; }, - }; + } + read_seq += 1; + cleartext_fragment_end -= 1; + const ct: tls.ContentType = @enumFromInt(cleartext_buf[cleartext_fragment_end]); + if (ct != .handshake) return error.TlsUnexpectedMessage; + break :content .{ tls.Decoder.fromTheirSlice(@constCast(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end])), ct }; + }, + .application => { + std.debug.assert(tls_version == .tls_1_2); + if (record_ct != .handshake) return error.TlsUnexpectedMessage; + try record_decoder.ensure(record_len); + const cleartext_buf = &cleartext_bufs[cert_buf_index % 2]; + switch (handshake_cipher) { + inline else => |*p| { + const pv = &p.version.tls_1_2; + const P = @TypeOf(p.*).A; + if (record_len < P.record_iv_length + P.mac_length) return error.TlsRecordOverflow; + const message_len: u16 = record_len - P.record_iv_length - P.mac_length; + const cleartext_fragment_buf = cleartext_buf[cleartext_fragment_end..]; + if (message_len > cleartext_fragment_buf.len) return error.TlsRecordOverflow; + const cleartext = cleartext_fragment_buf[0..message_len]; + const ad = std.mem.toBytes(big(read_seq)) ++ + record_header[0 .. 1 + 2] ++ + std.mem.toBytes(big(message_len)); + const record_iv = record_decoder.array(P.record_iv_length).*; + const masked_read_seq = read_seq & + comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); + const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.app_cipher.server_write_IV ++ record_iv; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ masked_read_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq))); + break :nonce @as(V, pv.app_cipher.server_write_IV ++ record_iv) ^ operand; + }; + const ciphertext = record_decoder.slice(message_len); + const auth_tag = record_decoder.array(P.mac_length); + P.AEAD.decrypt(cleartext, ciphertext, auth_tag.*, ad, nonce, pv.app_cipher.server_write_key) catch return error.TlsBadRecordMac; + cleartext_fragment_end += message_len; + }, + } + read_seq += 1; + break :content .{ tls.Decoder.fromTheirSlice(cleartext_buf[cleartext_fragment_start..cleartext_fragment_end]), record_ct }; + }, + }; + switch (ct) { + .alert => { + ctd.ensure(2) catch continue :fragment; + const level = ctd.decode(tls.AlertLevel); + const desc = ctd.decode(tls.AlertDescription); + _ = level; - const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - if (inner_ct != .handshake) return error.TlsUnexpectedMessage; - - var ctd = tls.Decoder.fromTheirSlice(cleartext[0 .. cleartext.len - 1]); - while (true) { - try ctd.ensure(4); - const handshake_type = ctd.decode(tls.HandshakeType); - const handshake_len = ctd.decode(u24); - var hsd = try ctd.sub(handshake_len); - const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; - const handshake = ctd.buf[ctd.idx - handshake_len .. ctd.idx]; - switch (handshake_type) { - .encrypted_extensions => { - if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; - handshake_state = .certificate; - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } + // if this isn't a error alert, then it's a closure alert, which makes no sense in a handshake + try desc.toError(); + // TODO: handle server-side closures + return error.TlsUnexpectedMessage; + }, + .change_cipher_spec => { + ctd.ensure(1) catch continue :fragment; + if (ctd.decode(tls.ChangeCipherSpecType) != .change_cipher_spec) return error.TlsIllegalParameter; + cipher_state = pending_cipher_state; + }, + .handshake => while (true) { + ctd.ensure(4) catch continue :fragment; + const handshake_type = ctd.decode(tls.HandshakeType); + const handshake_len = ctd.decode(u24); + var hsd = ctd.sub(handshake_len) catch continue :fragment; + const wrapped_handshake = ctd.buf[ctd.idx - handshake_len - 4 .. ctd.idx]; + switch (handshake_type) { + .server_hello => { + if (cipher_state != .cleartext) return error.TlsUnexpectedMessage; + if (handshake_state != .hello) return error.TlsUnexpectedMessage; + try hsd.ensure(2 + 32 + 1); + const legacy_version = hsd.decode(u16); + @memcpy(&server_hello_rand, hsd.array(32)); + if (mem.eql(u8, &server_hello_rand, &tls.hello_retry_request_sequence)) { + // This is a HelloRetryRequest message. This client implementation + // does not expect to get one. + return error.TlsUnexpectedMessage; + } + const legacy_session_id_echo_len = hsd.decode(u8); + try hsd.ensure(legacy_session_id_echo_len + 2 + 1); + const legacy_session_id_echo = hsd.slice(legacy_session_id_echo_len); + const cipher_suite_tag = hsd.decode(tls.CipherSuite); + hsd.skip(1); // legacy_compression_method + var supported_version: ?u16 = null; + if (!hsd.eof()) { try hsd.ensure(2); - const total_ext_size = hsd.decode(u16); - var all_extd = try hsd.sub(total_ext_size); + const extensions_size = hsd.decode(u16); + var all_extd = try hsd.sub(extensions_size); while (!all_extd.eof()) { - try all_extd.ensure(4); + try all_extd.ensure(2 + 2); const et = all_extd.decode(tls.ExtensionType); const ext_size = all_extd.decode(u16); - const extd = try all_extd.sub(ext_size); - _ = extd; + var extd = try all_extd.sub(ext_size); switch (et) { - .server_name => {}, + .supported_versions => { + if (supported_version) |_| return error.TlsIllegalParameter; + try extd.ensure(2); + supported_version = extd.decode(u16); + }, + .key_share => { + if (key_share.getSharedSecret()) |_| return error.TlsIllegalParameter; + try extd.ensure(4); + const named_group = extd.decode(tls.NamedGroup); + const key_size = extd.decode(u16); + try extd.ensure(key_size); + try key_share.exchange(named_group, extd.slice(key_size)); + }, else => {}, } } - }, - .certificate => cert: { - switch (handshake_cipher) { - inline else => |*p| p.transcript_hash.update(wrapped_handshake), - } - switch (handshake_state) { - .certificate => {}, - .trust_chain_established => break :cert, - else => return error.TlsUnexpectedMessage, - } - try hsd.ensure(1 + 4); - const cert_req_ctx_len = hsd.decode(u8); - if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; - const certs_size = hsd.decode(u24); - var certs_decoder = try hsd.sub(certs_size); - while (!certs_decoder.eof()) { - try certs_decoder.ensure(3); - const cert_size = certs_decoder.decode(u24); - const certd = try certs_decoder.sub(cert_size); + } - const subject_cert: Certificate = .{ - .buffer = certd.buf, - .index = @intCast(certd.idx), - }; - const subject = try subject_cert.parse(); - if (cert_index == 0) { - // Verify the host on the first certificate. - try subject.verifyHostName(host); + tls_version = @enumFromInt(supported_version orelse legacy_version); + switch (tls_version) { + .tls_1_3 => if (!mem.eql(u8, legacy_session_id_echo, &legacy_session_id)) return error.TlsIllegalParameter, + .tls_1_2 => if (mem.eql(u8, server_hello_rand[24..31], "DOWNGRD") and + server_hello_rand[31] >> 1 == 0x00) return error.TlsIllegalParameter, + else => return error.TlsIllegalParameter, + } - // Keep track of the public key for the - // certificate_verify message later. - main_cert_pub_key_algo = subject.pub_key_algo; - const pub_key = subject.pubKey(); - if (pub_key.len > main_cert_pub_key_buf.len) - return error.CertificatePublicKeyInvalid; - @memcpy(main_cert_pub_key_buf[0..pub_key.len], pub_key); - main_cert_pub_key_len = @intCast(pub_key.len); - } else { - try prev_cert.verify(subject, now_sec); + switch (cipher_suite_tag) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_256_SHA512, + .AEGIS_128L_SHA256, + + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => |tag| { + handshake_cipher = @unionInit(tls.HandshakeCipher, @tagName(tag.with()), .{ + .transcript_hash = .init(.{}), + .version = undefined, + }); + const p = &@field(handshake_cipher, @tagName(tag.with())); + p.transcript_hash.update(cleartext_header[tls.record_header_len..]); // Client Hello part 1 + p.transcript_hash.update(host); // Client Hello part 2 + p.transcript_hash.update(wrapped_handshake); + }, + + else => return error.TlsIllegalParameter, + } + switch (tls_version) { + .tls_1_3 => { + switch (cipher_suite_tag) { + inline .AES_128_GCM_SHA256, + .AES_256_GCM_SHA384, + .CHACHA20_POLY1305_SHA256, + .AEGIS_256_SHA512, + .AEGIS_128L_SHA256, + => |tag| { + const sk = key_share.getSharedSecret() orelse return error.TlsIllegalParameter; + const p = &@field(handshake_cipher, @tagName(tag.with())); + const P = @TypeOf(p.*).A; + const hello_hash = p.transcript_hash.peek(); + const zeroes = [1]u8{0} ** P.Hash.digest_length; + const early_secret = P.Hkdf.extract(&[1]u8{0}, &zeroes); + const empty_hash = tls.emptyHash(P.Hash); + p.version = .{ .tls_1_3 = undefined }; + const pv = &p.version.tls_1_3; + const hs_derived_secret = hkdfExpandLabel(P.Hkdf, early_secret, "derived", &empty_hash, P.Hash.digest_length); + pv.handshake_secret = P.Hkdf.extract(&hs_derived_secret, sk); + const ap_derived_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "derived", &empty_hash, P.Hash.digest_length); + pv.master_secret = P.Hkdf.extract(&ap_derived_secret, &zeroes); + const client_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "c hs traffic", &hello_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, pv.handshake_secret, "s hs traffic", &hello_hash, P.Hash.digest_length); + if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + .client_random = &client_hello_rand, + }, .{ + .SERVER_HANDSHAKE_TRAFFIC_SECRET = &server_secret, + .CLIENT_HANDSHAKE_TRAFFIC_SECRET = &client_secret, + }); + pv.client_finished_key = hkdfExpandLabel(P.Hkdf, client_secret, "finished", "", P.Hmac.key_length); + pv.server_finished_key = hkdfExpandLabel(P.Hkdf, server_secret, "finished", "", P.Hmac.key_length); + pv.client_handshake_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + pv.server_handshake_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + pv.client_handshake_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); + pv.server_handshake_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + }, + else => return error.TlsIllegalParameter, } + pending_cipher_state = .handshake; + handshake_state = .encrypted_extensions; + }, + .tls_1_2 => switch (cipher_suite_tag) { + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + => handshake_state = .certificate, + else => return error.TlsIllegalParameter, + }, + else => return error.TlsIllegalParameter, + } + }, + .encrypted_extensions => { + if (tls_version != .tls_1_3) return error.TlsUnexpectedMessage; + if (cipher_state != .handshake) return error.TlsUnexpectedMessage; + if (handshake_state != .encrypted_extensions) return error.TlsUnexpectedMessage; + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + try hsd.ensure(2); + const total_ext_size = hsd.decode(u16); + var all_extd = try hsd.sub(total_ext_size); + while (!all_extd.eof()) { + try all_extd.ensure(4); + const et = all_extd.decode(tls.ExtensionType); + const ext_size = all_extd.decode(u16); + const extd = try all_extd.sub(ext_size); + _ = extd; + switch (et) { + .server_name => {}, + else => {}, + } + } + handshake_state = .certificate; + }, + .certificate => cert: { + if (cipher_state == .application) return error.TlsUnexpectedMessage; + switch (handshake_state) { + .certificate => {}, + .trust_chain_established => break :cert, + else => return error.TlsUnexpectedMessage, + } + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } - if (ca_bundle.verify(subject, now_sec)) |_| { - handshake_state = .trust_chain_established; - break :cert; - } else |err| switch (err) { - error.CertificateIssuerNotFound => {}, - else => |e| return e, - } - - prev_cert = subject; - cert_index += 1; + switch (tls_version) { + .tls_1_3 => { + try hsd.ensure(1 + 3); + const cert_req_ctx_len = hsd.decode(u8); + if (cert_req_ctx_len != 0) return error.TlsIllegalParameter; + }, + .tls_1_2 => try hsd.ensure(3), + else => unreachable, + } + const certs_size = hsd.decode(u24); + var certs_decoder = try hsd.sub(certs_size); + while (!certs_decoder.eof()) { + try certs_decoder.ensure(3); + const cert_size = certs_decoder.decode(u24); + const certd = try certs_decoder.sub(cert_size); + if (tls_version == .tls_1_3) { try certs_decoder.ensure(2); const total_ext_size = certs_decoder.decode(u16); const all_extd = try certs_decoder.sub(total_ext_size); _ = all_extd; } - }, - .certificate_verify => { - switch (handshake_state) { - .trust_chain_established => handshake_state = .finished, - .certificate => return error.TlsCertificateNotVerified, - else => return error.TlsUnexpectedMessage, + + const subject_cert: Certificate = .{ + .buffer = certd.buf, + .index = @intCast(certd.idx), + }; + const subject = try subject_cert.parse(); + if (cert_index == 0) { + // Verify the host on the first certificate. + switch (options.host) { + .no_verification => {}, + .explicit => try subject.verifyHostName(host), + } + + // Keep track of the public key for the + // certificate_verify message later. + try main_cert_pub_key.init(subject.pub_key_algo, subject.pubKey()); + } else { + try prev_cert.verify(subject, now_sec); } - try hsd.ensure(4); - const scheme = hsd.decode(tls.SignatureScheme); - const sig_len = hsd.decode(u16); - try hsd.ensure(sig_len); - const encoded_sig = hsd.slice(sig_len); - const max_digest_len = 64; - var verify_buffer: [64 + 34 + max_digest_len]u8 = - ([1]u8{0x20} ** 64) ++ - "TLS 1.3, server CertificateVerify\x00".* ++ - @as([max_digest_len]u8, undefined); - - const verify_bytes = switch (handshake_cipher) { - inline else => |*p| v: { - const transcript_digest = p.transcript_hash.peek(); - verify_buffer[verify_buffer.len - max_digest_len ..][0..transcript_digest.len].* = transcript_digest; - p.transcript_hash.update(wrapped_handshake); - break :v verify_buffer[0 .. verify_buffer.len - max_digest_len + transcript_digest.len]; + switch (options.ca) { + .no_verification => { + handshake_state = .trust_chain_established; + break :cert; }, - }; - const main_cert_pub_key = main_cert_pub_key_buf[0..main_cert_pub_key_len]; - - switch (scheme) { - inline .ecdsa_secp256r1_sha256, - .ecdsa_secp384r1_sha384, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .X9_62_id_ecPublicKey) - return error.TlsBadSignatureScheme; - const Ecdsa = SchemeEcdsa(comptime_scheme); - const sig = try Ecdsa.Signature.fromDer(encoded_sig); - const key = try Ecdsa.PublicKey.fromSec1(main_cert_pub_key); - try sig.verify(verify_bytes, key); + .self_signed => { + try subject.verify(subject, now_sec); + handshake_state = .trust_chain_established; + break :cert; }, - inline .rsa_pss_rsae_sha256, - .rsa_pss_rsae_sha384, - .rsa_pss_rsae_sha512, - => |comptime_scheme| { - if (main_cert_pub_key_algo != .rsaEncryption) - return error.TlsBadSignatureScheme; - - const Hash = SchemeHash(comptime_scheme); - const rsa = Certificate.rsa; - const components = try rsa.PublicKey.parseDer(main_cert_pub_key); - const exponent = components.exponent; - const modulus = components.modulus; - switch (modulus.len) { - inline 128, 256, 512 => |modulus_len| { - const key = try rsa.PublicKey.fromBytes(exponent, modulus); - const sig = rsa.PSSSignature.fromBytes(modulus_len, encoded_sig); - try rsa.PSSSignature.verify(modulus_len, sig, verify_bytes, key, Hash); - }, - else => { - return error.TlsBadRsaSignatureBitCount; - }, - } - }, - inline .ed25519 => |comptime_scheme| { - if (main_cert_pub_key_algo != .curveEd25519) return error.TlsBadSignatureScheme; - const Eddsa = SchemeEddsa(comptime_scheme); - if (encoded_sig.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; - const sig = Eddsa.Signature.fromBytes(encoded_sig[0..Eddsa.Signature.encoded_length].*); - if (main_cert_pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; - const key = try Eddsa.PublicKey.fromBytes(main_cert_pub_key[0..Eddsa.PublicKey.encoded_length].*); - try sig.verify(verify_bytes, key); - }, - else => { - return error.TlsBadSignatureScheme; + .bundle => |ca_bundle| if (ca_bundle.verify(subject, now_sec)) |_| { + handshake_state = .trust_chain_established; + break :cert; + } else |err| switch (err) { + error.CertificateIssuerNotFound => {}, + else => |e| return e, }, } - }, - .finished => { - if (handshake_state != .finished) return error.TlsUnexpectedMessage; - // This message is to trick buggy proxies into behaving correctly. - const client_change_cipher_spec_msg = [_]u8{ - @intFromEnum(tls.ContentType.change_cipher_spec), - 0x03, 0x03, // legacy protocol version - 0x00, 0x01, // length - 0x01, - }; - const app_cipher = switch (handshake_cipher) { - inline else => |*p, tag| c: { - const P = @TypeOf(p.*); + + prev_cert = subject; + cert_index += 1; + } + cert_buf_index += 1; + }, + .server_key_exchange => { + if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage; + if (cipher_state != .cleartext) return error.TlsUnexpectedMessage; + switch (handshake_state) { + .trust_chain_established => {}, + .certificate => return error.TlsCertificateNotVerified, + else => return error.TlsUnexpectedMessage, + } + + switch (handshake_cipher) { + inline else => |*p| p.transcript_hash.update(wrapped_handshake), + } + try hsd.ensure(1 + 2 + 1); + const curve_type = hsd.decode(u8); + if (curve_type != 0x03) return error.TlsIllegalParameter; // named_curve + const named_group = hsd.decode(tls.NamedGroup); + const key_size = hsd.decode(u8); + try hsd.ensure(key_size); + const server_pub_key = hsd.slice(key_size); + try main_cert_pub_key.verifySignature(&hsd, &.{ &client_hello_rand, &server_hello_rand, hsd.buf[0..hsd.idx] }); + try key_share.exchange(named_group, server_pub_key); + handshake_state = .server_hello_done; + }, + .server_hello_done => { + if (tls_version != .tls_1_2) return error.TlsUnexpectedMessage; + if (cipher_state != .cleartext) return error.TlsUnexpectedMessage; + if (handshake_state != .server_hello_done) return error.TlsUnexpectedMessage; + + const client_key_exchange_msg = .{@intFromEnum(tls.ContentType.handshake)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, u8, .{@intFromEnum(tls.HandshakeType.client_key_exchange)} ++ + array(u24, u8, array(u8, u8, key_share.secp256r1_kp.public_key.toUncompressedSec1()))); + const client_change_cipher_spec_msg = .{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, tls.ChangeCipherSpecType, .{.change_cipher_spec}); + const pre_master_secret = key_share.getSharedSecret().?; + switch (handshake_cipher) { + inline else => |*p| { + const P = @TypeOf(p.*).A; + p.transcript_hash.update(wrapped_handshake); + p.transcript_hash.update(client_key_exchange_msg[tls.record_header_len..]); + const master_secret = hmacExpandLabel(P.Hmac, pre_master_secret, &.{ + "master secret", + &client_hello_rand, + &server_hello_rand, + }, 48); + if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + .client_random = &client_hello_rand, + }, .{ + .CLIENT_RANDOM = &master_secret, + }); + const key_block = hmacExpandLabel( + P.Hmac, + &master_secret, + &.{ "key expansion", &server_hello_rand, &client_hello_rand }, + @sizeOf(P.Tls_1_2), + ); + const client_verify_cleartext = .{@intFromEnum(tls.HandshakeType.finished)} ++ + array(u24, u8, hmacExpandLabel( + P.Hmac, + &master_secret, + &.{ "client finished", &p.transcript_hash.peek() }, + P.verify_data_length, + )); + p.transcript_hash.update(&client_verify_cleartext); + p.version = .{ .tls_1_2 = .{ + .expected_server_verify_data = hmacExpandLabel( + P.Hmac, + &master_secret, + &.{ "server finished", &p.transcript_hash.finalResult() }, + P.verify_data_length, + ), + .app_cipher = std.mem.bytesToValue(P.Tls_1_2, &key_block), + } }; + const pv = &p.version.tls_1_2; + const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.app_cipher.client_write_IV ++ pv.app_cipher.client_salt; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ write_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @as([8]u8, @bitCast(big(write_seq))); + break :nonce @as(V, pv.app_cipher.client_write_IV ++ pv.app_cipher.client_salt) ^ operand; + }; + var client_verify_msg = .{@intFromEnum(tls.ContentType.handshake)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, u8, nonce[P.fixed_iv_length..].* ++ + @as([client_verify_cleartext.len + P.mac_length]u8, undefined)); + P.AEAD.encrypt( + client_verify_msg[client_verify_msg.len - P.mac_length - + client_verify_cleartext.len ..][0..client_verify_cleartext.len], + client_verify_msg[client_verify_msg.len - P.mac_length ..][0..P.mac_length], + &client_verify_cleartext, + std.mem.toBytes(big(write_seq)) ++ client_verify_msg[0 .. 1 + 2] ++ int(u16, client_verify_cleartext.len), + nonce, + pv.app_cipher.client_write_key, + ); + const all_msgs = client_key_exchange_msg ++ client_change_cipher_spec_msg ++ client_verify_msg; + var all_msgs_vec = [_]std.posix.iovec_const{ + .{ .base = &all_msgs, .len = all_msgs.len }, + }; + try stream.writevAll(&all_msgs_vec); + }, + } + write_seq += 1; + pending_cipher_state = .application; + handshake_state = .finished; + }, + .certificate_verify => { + if (tls_version != .tls_1_3) return error.TlsUnexpectedMessage; + if (cipher_state != .handshake) return error.TlsUnexpectedMessage; + switch (handshake_state) { + .trust_chain_established => {}, + .certificate => return error.TlsCertificateNotVerified, + else => return error.TlsUnexpectedMessage, + } + switch (handshake_cipher) { + inline else => |*p| { + try main_cert_pub_key.verifySignature(&hsd, &.{ + " " ** 64 ++ "TLS 1.3, server CertificateVerify\x00", + &p.transcript_hash.peek(), + }); + p.transcript_hash.update(wrapped_handshake); + }, + } + handshake_state = .finished; + }, + .finished => { + if (cipher_state == .cleartext) return error.TlsUnexpectedMessage; + if (handshake_state != .finished) return error.TlsUnexpectedMessage; + // This message is to trick buggy proxies into behaving correctly. + const client_change_cipher_spec_msg = .{@intFromEnum(tls.ContentType.change_cipher_spec)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, tls.ChangeCipherSpecType, .{.change_cipher_spec}); + const app_cipher = app_cipher: switch (handshake_cipher) { + inline else => |*p, tag| switch (tls_version) { + .tls_1_3 => { + const pv = &p.version.tls_1_3; + const P = @TypeOf(p.*).A; + try hsd.ensure(P.Hmac.mac_length); const finished_digest = p.transcript_hash.peek(); p.transcript_hash.update(wrapped_handshake); - const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key); - if (!mem.eql(u8, &expected_server_verify_data, handshake)) - return error.TlsDecryptError; + const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, pv.server_finished_key); + if (!std.crypto.timing_safe.eql([P.Hmac.mac_length]u8, expected_server_verify_data, hsd.array(P.Hmac.mac_length).*)) return error.TlsDecryptError; const handshake_hash = p.transcript_hash.finalResult(); - const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key); - const out_cleartext = [_]u8{ - @intFromEnum(tls.HandshakeType.finished), - 0, 0, verify_data.len, // length - } ++ verify_data ++ [1]u8{@intFromEnum(tls.ContentType.handshake)}; + const verify_data = tls.hmac(P.Hmac, &handshake_hash, pv.client_finished_key); + const out_cleartext = .{@intFromEnum(tls.HandshakeType.finished)} ++ + array(u24, u8, verify_data) ++ + .{@intFromEnum(tls.ContentType.handshake)}; const wrapped_len = out_cleartext.len + P.AEAD.tag_length; - var finished_msg = [_]u8{ - @intFromEnum(tls.ContentType.application_data), - 0x03, 0x03, // legacy protocol version - 0, wrapped_len, // byte length of encrypted record - } ++ @as([wrapped_len]u8, undefined); + var finished_msg = .{@intFromEnum(tls.ContentType.application_data)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + array(u16, u8, @as([wrapped_len]u8, undefined)); - const ad = finished_msg[0..5]; - const ciphertext = finished_msg[5..][0..out_cleartext.len]; + const ad = finished_msg[0..tls.record_header_len]; + const ciphertext = finished_msg[tls.record_header_len..][0..out_cleartext.len]; const auth_tag = finished_msg[finished_msg.len - P.AEAD.tag_length ..]; - const nonce = p.client_handshake_iv; - P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key); + const nonce = pv.client_handshake_iv; + P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, pv.client_handshake_key); - const both_msgs = client_change_cipher_spec_msg ++ finished_msg; - var both_msgs_vec = [_]std.posix.iovec_const{.{ - .base = &both_msgs, - .len = both_msgs.len, - }}; - try stream.writevAll(&both_msgs_vec); + const all_msgs = client_change_cipher_spec_msg ++ finished_msg; + var all_msgs_vec = [_]std.posix.iovec_const{ + .{ .base = &all_msgs, .len = all_msgs.len }, + }; + try stream.writevAll(&all_msgs_vec); - const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); - const server_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); - break :c @unionInit(tls.ApplicationCipher, @tagName(tag), .{ + const client_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length); + const server_secret = hkdfExpandLabel(P.Hkdf, pv.master_secret, "s ap traffic", &handshake_hash, P.Hash.digest_length); + if (options.ssl_key_log_file) |key_log_file| logSecrets(key_log_file, .{ + .counter = key_seq, + .client_random = &client_hello_rand, + }, .{ + .SERVER_TRAFFIC_SECRET = &server_secret, + .CLIENT_TRAFFIC_SECRET = &client_secret, + }); + key_seq += 1; + break :app_cipher @unionInit(tls.ApplicationCipher, @tagName(tag), .{ .tls_1_3 = .{ .client_secret = client_secret, .server_secret = server_secret, .client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length), .server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length), .client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length), .server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length), - }); + } }); }, - }; - const leftover = d.rest(); - var client: Client = .{ - .read_seq = 0, - .write_seq = 0, - .partial_cleartext_idx = 0, - .partial_ciphertext_idx = 0, - .partial_ciphertext_end = @intCast(leftover.len), - .received_close_notify = false, - .application_cipher = app_cipher, - .partially_read_buffer = undefined, - }; - @memcpy(client.partially_read_buffer[0..leftover.len], leftover); - return client; - }, - else => { - return error.TlsUnexpectedMessage; - }, - } - if (ctd.eof()) break; + .tls_1_2 => { + const pv = &p.version.tls_1_2; + const P = @TypeOf(p.*).A; + try hsd.ensure(P.verify_data_length); + if (!std.crypto.timing_safe.eql([P.verify_data_length]u8, pv.expected_server_verify_data, hsd.array(P.verify_data_length).*)) return error.TlsDecryptError; + break :app_cipher @unionInit(tls.ApplicationCipher, @tagName(tag), .{ .tls_1_2 = pv.app_cipher }); + }, + else => unreachable, + }, + }; + const leftover = d.rest(); + var client: Client = .{ + .tls_version = tls_version, + .read_seq = switch (tls_version) { + .tls_1_3 => 0, + .tls_1_2 => read_seq, + else => unreachable, + }, + .write_seq = switch (tls_version) { + .tls_1_3 => 0, + .tls_1_2 => write_seq, + else => unreachable, + }, + .partial_cleartext_idx = 0, + .partial_ciphertext_idx = 0, + .partial_ciphertext_end = @intCast(leftover.len), + .received_close_notify = false, + .allow_truncation_attacks = false, + .application_cipher = app_cipher, + .partially_read_buffer = undefined, + .ssl_key_log = if (options.ssl_key_log_file) |key_log_file| .{ + .client_key_seq = key_seq, + .server_key_seq = key_seq, + .client_random = client_hello_rand, + .file = key_log_file, + } else null, + }; + @memcpy(client.partially_read_buffer[0..leftover.len], leftover); + return client; + }, + else => return error.TlsUnexpectedMessage, } + if (ctd.eof()) break; + cleartext_fragment_start = ctd.idx; }, - else => { - return error.TlsUnexpectedMessage; - }, + else => return error.TlsUnexpectedMessage, } + cleartext_fragment_start = 0; + cleartext_fragment_end = 0; } } /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. +/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. pub fn write(c: *Client, stream: anytype, bytes: []const u8) !usize { return writeEnd(c, stream, bytes, false); } @@ -749,7 +944,7 @@ pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !v } /// Sends TLS-encrypted data to `stream`, which must conform to `StreamInterface`. -/// Returns the number of plaintext bytes sent, which may be fewer than `bytes.len`. +/// Returns the number of cleartext bytes sent, which may be fewer than `bytes.len`. /// If `end` is true, then this function additionally sends a `close_notify` alert, /// which is necessary for the server to distinguish between a properly finished /// TLS session, or a truncation attack. @@ -813,62 +1008,126 @@ fn prepareCiphertextRecord( var iovec_end: usize = 0; var bytes_i: usize = 0; switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; - const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; - while (true) { - const encrypted_content_len: u16 = @intCast(@min( - @min(bytes.len - bytes_i, tls.max_ciphertext_inner_record_len), - ciphertext_buf.len -| - (close_notify_alert_reserved + overhead_len + ciphertext_end), - )); - if (encrypted_content_len == 0) return .{ - .iovec_end = iovec_end, - .ciphertext_end = ciphertext_end, - .overhead_len = overhead_len, - }; + inline else => |*p| switch (c.tls_version) { + .tls_1_3 => { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const overhead_len = tls.record_header_len + P.AEAD.tag_length + 1; + const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; + while (true) { + const encrypted_content_len: u16 = @min( + bytes.len - bytes_i, + tls.max_ciphertext_inner_record_len, + ciphertext_buf.len -| + (close_notify_alert_reserved + overhead_len + ciphertext_end), + ); + if (encrypted_content_len == 0) return .{ + .iovec_end = iovec_end, + .ciphertext_end = ciphertext_end, + .overhead_len = overhead_len, + }; - @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); - cleartext_buf[encrypted_content_len] = @intFromEnum(inner_content_type); - bytes_i += encrypted_content_len; - const ciphertext_len = encrypted_content_len + 1; - const cleartext = cleartext_buf[0..ciphertext_len]; + @memcpy(cleartext_buf[0..encrypted_content_len], bytes[bytes_i..][0..encrypted_content_len]); + cleartext_buf[encrypted_content_len] = @intFromEnum(inner_content_type); + bytes_i += encrypted_content_len; + const ciphertext_len = encrypted_content_len + 1; + const cleartext = cleartext_buf[0..ciphertext_len]; - const record_start = ciphertext_end; - const ad = ciphertext_buf[ciphertext_end..][0..5]; - ad.* = - [_]u8{@intFromEnum(tls.ContentType.application_data)} ++ - int2(@intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ - int2(ciphertext_len + P.AEAD.tag_length); - ciphertext_end += ad.len; - const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; - ciphertext_end += ciphertext_len; - const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; - ciphertext_end += auth_tag.len; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.client_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.write_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.write_seq))); - break :nonce @as(V, p.client_iv) ^ operand; - }; - c.write_seq += 1; // TODO send key_update on overflow - P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key); + const record_start = ciphertext_end; + const ad = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; + ad.* = .{@intFromEnum(tls.ContentType.application_data)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + int(u16, ciphertext_len + P.AEAD.tag_length); + ciphertext_end += ad.len; + const ciphertext = ciphertext_buf[ciphertext_end..][0..ciphertext_len]; + ciphertext_end += ciphertext_len; + const auth_tag = ciphertext_buf[ciphertext_end..][0..P.AEAD.tag_length]; + ciphertext_end += auth_tag.len; + const nonce = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.client_iv; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.write_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ std.mem.toBytes(big(c.write_seq)); + break :nonce @as(V, pv.client_iv) ^ operand; + }; + P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_key); + c.write_seq += 1; // TODO send key_update on overflow - const record = ciphertext_buf[record_start..ciphertext_end]; - iovecs[iovec_end] = .{ - .base = record.ptr, - .len = record.len, - }; - iovec_end += 1; - } + const record = ciphertext_buf[record_start..ciphertext_end]; + iovecs[iovec_end] = .{ + .base = record.ptr, + .len = record.len, + }; + iovec_end += 1; + } + }, + .tls_1_2 => { + const pv = &p.tls_1_2; + const P = @TypeOf(p.*); + const overhead_len = tls.record_header_len + P.record_iv_length + P.mac_length; + const close_notify_alert_reserved = tls.close_notify_alert.len + overhead_len; + while (true) { + const message_len: u16 = @min( + bytes.len - bytes_i, + tls.max_ciphertext_inner_record_len, + ciphertext_buf.len -| + (close_notify_alert_reserved + overhead_len + ciphertext_end), + ); + if (message_len == 0) return .{ + .iovec_end = iovec_end, + .ciphertext_end = ciphertext_end, + .overhead_len = overhead_len, + }; + + @memcpy(cleartext_buf[0..message_len], bytes[bytes_i..][0..message_len]); + bytes_i += message_len; + const cleartext = cleartext_buf[0..message_len]; + + const record_start = ciphertext_end; + const record_header = ciphertext_buf[ciphertext_end..][0..tls.record_header_len]; + ciphertext_end += tls.record_header_len; + record_header.* = .{@intFromEnum(inner_content_type)} ++ + int(u16, @intFromEnum(tls.ProtocolVersion.tls_1_2)) ++ + int(u16, P.record_iv_length + message_len + P.mac_length); + const ad = std.mem.toBytes(big(c.write_seq)) ++ record_header[0 .. 1 + 2] ++ int(u16, message_len); + const record_iv = ciphertext_buf[ciphertext_end..][0..P.record_iv_length]; + ciphertext_end += P.record_iv_length; + const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.client_write_IV ++ pv.client_salt; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.write_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @as([8]u8, @bitCast(big(c.write_seq))); + break :nonce @as(V, pv.client_write_IV ++ pv.client_salt) ^ operand; + }; + record_iv.* = nonce[P.fixed_iv_length..].*; + const ciphertext = ciphertext_buf[ciphertext_end..][0..message_len]; + ciphertext_end += message_len; + const auth_tag = ciphertext_buf[ciphertext_end..][0..P.mac_length]; + ciphertext_end += P.mac_length; + P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, pv.client_write_key); + c.write_seq += 1; // TODO send key_update on overflow + + const record = ciphertext_buf[record_start..ciphertext_end]; + iovecs[iovec_end] = .{ + .base = record.ptr, + .len = record.len, + }; + iovec_end += 1; + } + }, + else => unreachable, }, } } @@ -990,7 +1249,7 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove // beginning of the buffer will be used for such purposes. const cleartext_buf_len = free_size - ciphertext_buf_len; - // Recoup `partially_read_buffer space`. This is necessary because it is assumed + // Recoup `partially_read_buffer` space. This is necessary because it is assumed // below that `frag0` is big enough to hold at least one record. limitedOverlapCopy(c.partially_read_buffer[0..c.partial_ciphertext_end], c.partial_ciphertext_idx); c.partial_ciphertext_end -= c.partial_ciphertext_idx; @@ -1105,164 +1364,211 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove in = 0; continue; } - switch (ct) { + const cleartext, const inner_ct: tls.ContentType = cleartext: switch (c.application_cipher) { + inline else => |*p| switch (c.tls_version) { + .tls_1_3 => { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const ad = frag[in - tls.record_header_len ..][0..tls.record_header_len]; + const ciphertext_len = record_len - P.AEAD.tag_length; + const ciphertext = frag[in..][0..ciphertext_len]; + in += ciphertext_len; + const auth_tag = frag[in..][0..P.AEAD.tag_length].*; + const nonce = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.server_iv; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.read_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ std.mem.toBytes(big(c.read_seq)); + break :nonce @as(V, pv.server_iv) ^ operand; + }; + const out_buf = vp.peek(); + const cleartext_buf = if (ciphertext.len <= out_buf.len) + out_buf + else + &cleartext_stack_buffer; + const cleartext = cleartext_buf[0..ciphertext.len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_key) catch + return error.TlsBadRecordMac; + const msg = mem.trimRight(u8, cleartext, "\x00"); + break :cleartext .{ msg[0 .. msg.len - 1], @enumFromInt(msg[msg.len - 1]) }; + }, + .tls_1_2 => { + const pv = &p.tls_1_2; + const P = @TypeOf(p.*); + const message_len: u16 = record_len - P.record_iv_length - P.mac_length; + const ad = std.mem.toBytes(big(c.read_seq)) ++ + frag[in - tls.record_header_len ..][0 .. 1 + 2] ++ + std.mem.toBytes(big(message_len)); + const record_iv = frag[in..][0..P.record_iv_length].*; + in += P.record_iv_length; + const masked_read_seq = c.read_seq & + comptime std.math.shl(u64, std.math.maxInt(u64), 8 * P.record_iv_length); + const nonce: [P.AEAD.nonce_length]u8 = if (builtin.zig_backend == .stage2_x86_64 and + P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) + nonce: { + var nonce = pv.server_write_IV ++ record_iv; + const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); + std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ masked_read_seq, .big); + break :nonce nonce; + } else nonce: { + const V = @Vector(P.AEAD.nonce_length, u8); + const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); + const operand: V = pad ++ @as([8]u8, @bitCast(big(masked_read_seq))); + break :nonce @as(V, pv.server_write_IV ++ record_iv) ^ operand; + }; + const ciphertext = frag[in..][0..message_len]; + in += message_len; + const auth_tag = frag[in..][0..P.mac_length].*; + in += P.mac_length; + const out_buf = vp.peek(); + const cleartext_buf = if (message_len <= out_buf.len) + out_buf + else + &cleartext_stack_buffer; + const cleartext = cleartext_buf[0..ciphertext.len]; + P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, pv.server_write_key) catch + return error.TlsBadRecordMac; + break :cleartext .{ cleartext, ct }; + }, + else => unreachable, + }, + }; + c.read_seq = try std.math.add(u64, c.read_seq, 1); + switch (inner_ct) { .alert => { - if (in + 2 > frag.len) return error.TlsDecodeError; - const level: tls.AlertLevel = @enumFromInt(frag[in]); - const desc: tls.AlertDescription = @enumFromInt(frag[in + 1]); + if (cleartext.len != 2) return error.TlsDecodeError; + const level: tls.AlertLevel = @enumFromInt(cleartext[0]); + const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); + if (desc == .close_notify) { + c.received_close_notify = true; + c.partial_ciphertext_end = c.partial_ciphertext_idx; + return vp.total; + } _ = level; try desc.toError(); // TODO: handle server-side closures return error.TlsUnexpectedMessage; }, - .application_data => { - const cleartext = switch (c.application_cipher) { - inline else => |*p| c: { - const P = @TypeOf(p.*); - const ad = frag[in - 5 ..][0..5]; - const ciphertext_len = record_len - P.AEAD.tag_length; - const ciphertext = frag[in..][0..ciphertext_len]; - in += ciphertext_len; - const auth_tag = frag[in..][0..P.AEAD.tag_length].*; - const nonce = if (builtin.zig_backend == .stage2_x86_64 and - P.AEAD.nonce_length > comptime std.simd.suggestVectorLength(u8) orelse 1) - nonce: { - var nonce = p.server_iv; - const operand = std.mem.readInt(u64, nonce[nonce.len - 8 ..], .big); - std.mem.writeInt(u64, nonce[nonce.len - 8 ..], operand ^ c.read_seq, .big); - break :nonce nonce; - } else nonce: { - const V = @Vector(P.AEAD.nonce_length, u8); - const pad = [1]u8{0} ** (P.AEAD.nonce_length - 8); - const operand: V = pad ++ @as([8]u8, @bitCast(big(c.read_seq))); - break :nonce @as(V, p.server_iv) ^ operand; - }; - const out_buf = vp.peek(); - const cleartext_buf = if (ciphertext.len <= out_buf.len) - out_buf - else - &cleartext_stack_buffer; - const cleartext = cleartext_buf[0..ciphertext.len]; - P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_key) catch - return error.TlsBadRecordMac; - break :c mem.trimRight(u8, cleartext, "\x00"); - }, - }; - - c.read_seq = try std.math.add(u64, c.read_seq, 1); - - const inner_ct: tls.ContentType = @enumFromInt(cleartext[cleartext.len - 1]); - switch (inner_ct) { - .alert => { - const level: tls.AlertLevel = @enumFromInt(cleartext[0]); - const desc: tls.AlertDescription = @enumFromInt(cleartext[1]); - if (desc == .close_notify) { - c.received_close_notify = true; - c.partial_ciphertext_end = c.partial_ciphertext_idx; - return vp.total; - } - _ = level; - - try desc.toError(); - // TODO: handle server-side closures - return error.TlsUnexpectedMessage; - }, - .handshake => { - var ct_i: usize = 0; - while (true) { - const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); - ct_i += 1; - const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); - ct_i += 3; - const next_handshake_i = ct_i + handshake_len; - if (next_handshake_i > cleartext.len - 1) - return error.TlsBadLength; - const handshake = cleartext[ct_i..next_handshake_i]; - switch (handshake_type) { - .new_session_ticket => { - // This client implementation ignores new session tickets. + .handshake => { + var ct_i: usize = 0; + while (true) { + const handshake_type: tls.HandshakeType = @enumFromInt(cleartext[ct_i]); + ct_i += 1; + const handshake_len = mem.readInt(u24, cleartext[ct_i..][0..3], .big); + ct_i += 3; + const next_handshake_i = ct_i + handshake_len; + if (next_handshake_i > cleartext.len) + return error.TlsBadLength; + const handshake = cleartext[ct_i..next_handshake_i]; + switch (handshake_type) { + .new_session_ticket => { + // This client implementation ignores new session tickets. + }, + .key_update => { + switch (c.application_cipher) { + inline else => |*p| { + const pv = &p.tls_1_3; + const P = @TypeOf(p.*); + const server_secret = hkdfExpandLabel(P.Hkdf, pv.server_secret, "traffic upd", "", P.Hash.digest_length); + if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{ + .counter = key_log.serverCounter(), + .client_random = &key_log.client_random, + }, .{ + .SERVER_TRAFFIC_SECRET = &server_secret, + }); + pv.server_secret = server_secret; + pv.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); + pv.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); }, - .key_update => { + } + c.read_seq = 0; + + switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { + .update_requested => { switch (c.application_cipher) { inline else => |*p| { + const pv = &p.tls_1_3; const P = @TypeOf(p.*); - const server_secret = hkdfExpandLabel(P.Hkdf, p.server_secret, "traffic upd", "", P.Hash.digest_length); - p.server_secret = server_secret; - p.server_key = hkdfExpandLabel(P.Hkdf, server_secret, "key", "", P.AEAD.key_length); - p.server_iv = hkdfExpandLabel(P.Hkdf, server_secret, "iv", "", P.AEAD.nonce_length); + const client_secret = hkdfExpandLabel(P.Hkdf, pv.client_secret, "traffic upd", "", P.Hash.digest_length); + if (c.ssl_key_log) |*key_log| logSecrets(key_log.file, .{ + .counter = key_log.clientCounter(), + .client_random = &key_log.client_random, + }, .{ + .CLIENT_TRAFFIC_SECRET = &client_secret, + }); + pv.client_secret = client_secret; + pv.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); + pv.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); }, } - c.read_seq = 0; - - switch (@as(tls.KeyUpdateRequest, @enumFromInt(handshake[0]))) { - .update_requested => { - switch (c.application_cipher) { - inline else => |*p| { - const P = @TypeOf(p.*); - const client_secret = hkdfExpandLabel(P.Hkdf, p.client_secret, "traffic upd", "", P.Hash.digest_length); - p.client_secret = client_secret; - p.client_key = hkdfExpandLabel(P.Hkdf, client_secret, "key", "", P.AEAD.key_length); - p.client_iv = hkdfExpandLabel(P.Hkdf, client_secret, "iv", "", P.AEAD.nonce_length); - }, - } - c.write_seq = 0; - }, - .update_not_requested => {}, - _ => return error.TlsIllegalParameter, - } - }, - else => { - return error.TlsUnexpectedMessage; + c.write_seq = 0; }, + .update_not_requested => {}, + _ => return error.TlsIllegalParameter, } - ct_i = next_handshake_i; - if (ct_i >= cleartext.len - 1) break; - } - }, - .application_data => { - // Determine whether the output buffer or a stack - // buffer was used for storing the cleartext. - if (cleartext.ptr == &cleartext_stack_buffer) { - // Stack buffer was used, so we must copy to the output buffer. - const msg = cleartext[0 .. cleartext.len - 1]; - if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { - // We have already run out of room in iovecs. Continue - // appending to `partially_read_buffer`. - @memcpy( - c.partially_read_buffer[c.partial_ciphertext_idx..][0..msg.len], - msg, - ); - c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + msg.len); - } else { - const amt = vp.put(msg); - if (amt < msg.len) { - const rest = msg[amt..]; - c.partial_cleartext_idx = 0; - c.partial_ciphertext_idx = @intCast(rest.len); - @memcpy(c.partially_read_buffer[0..rest.len], rest); - } - } - } else { - // Output buffer was used directly which means no - // memory copying needs to occur, and we can move - // on to the next ciphertext record. - vp.next(cleartext.len - 1); - } - }, - else => { - return error.TlsUnexpectedMessage; - }, + }, + else => { + return error.TlsUnexpectedMessage; + }, + } + ct_i = next_handshake_i; + if (ct_i >= cleartext.len) break; } }, - else => { - return error.TlsUnexpectedMessage; + .application_data => { + // Determine whether the output buffer or a stack + // buffer was used for storing the cleartext. + if (cleartext.ptr == &cleartext_stack_buffer) { + // Stack buffer was used, so we must copy to the output buffer. + if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { + // We have already run out of room in iovecs. Continue + // appending to `partially_read_buffer`. + @memcpy( + c.partially_read_buffer[c.partial_ciphertext_idx..][0..cleartext.len], + cleartext, + ); + c.partial_ciphertext_idx = @intCast(c.partial_ciphertext_idx + cleartext.len); + } else { + const amt = vp.put(cleartext); + if (amt < cleartext.len) { + const rest = cleartext[amt..]; + c.partial_cleartext_idx = 0; + c.partial_ciphertext_idx = @intCast(rest.len); + @memcpy(c.partially_read_buffer[0..rest.len], rest); + } + } + } else { + // Output buffer was used directly which means no + // memory copying needs to occur, and we can move + // on to the next ciphertext record. + vp.next(cleartext.len); + } }, + else => return error.TlsUnexpectedMessage, } in = end; } } +fn logSecrets(key_log_file: std.fs.File, context: anytype, secrets: anytype) void { + const locked = if (key_log_file.lock(.exclusive)) |_| true else |_| false; + defer if (locked) key_log_file.unlock(); + key_log_file.seekFromEnd(0) catch {}; + inline for (@typeInfo(@TypeOf(secrets)).@"struct".fields) |field| key_log_file.writer().print("{s}" ++ + (if (@hasField(@TypeOf(context), "counter")) "_{d}" else "") ++ " {} {}\n", .{field.name} ++ + (if (@hasField(@TypeOf(context), "counter")) .{context.counter} else .{}) ++ .{ + std.fmt.fmtSliceHexLower(context.client_random), + std.fmt.fmtSliceHexLower(@field(secrets, field.name)), + }) catch {}; +} + fn finishRead(c: *Client, frag: []const u8, in: usize, out: usize) usize { const saved_buf = frag[in..]; if (c.partial_ciphertext_idx > c.partial_cleartext_idx) { @@ -1326,6 +1632,86 @@ inline fn big(x: anytype) @TypeOf(x) { }; } +const KeyShare = struct { + ml_kem768_kp: crypto.kem.ml_kem.MLKem768.KeyPair, + secp256r1_kp: crypto.sign.ecdsa.EcdsaP256Sha256.KeyPair, + secp384r1_kp: crypto.sign.ecdsa.EcdsaP384Sha384.KeyPair, + x25519_kp: crypto.dh.X25519.KeyPair, + sk_buf: [sk_max_len]u8, + sk_len: std.math.IntFittingRange(0, sk_max_len), + + const sk_max_len = @max( + crypto.dh.X25519.shared_length + crypto.kem.ml_kem.MLKem768.shared_length, + crypto.ecc.P256.scalar.encoded_length, + crypto.ecc.P384.scalar.encoded_length, + crypto.dh.X25519.shared_length, + ); + + fn init(seed: [112]u8) error{IdentityElement}!KeyShare { + return .{ + .ml_kem768_kp = try .create(null), + .secp256r1_kp = try .create(seed[0..32].*), + .secp384r1_kp = try .create(seed[32..80].*), + .x25519_kp = try .create(seed[80..112].*), + .sk_buf = undefined, + .sk_len = 0, + }; + } + + fn exchange( + ks: *KeyShare, + named_group: tls.NamedGroup, + server_pub_key: []const u8, + ) error{ TlsIllegalParameter, TlsDecryptFailure }!void { + switch (named_group) { + .x25519_ml_kem768 => { + const hksl = crypto.kem.ml_kem.MLKem768.ciphertext_length; + const xksl = hksl + crypto.dh.X25519.public_length; + if (server_pub_key.len != xksl) return error.TlsIllegalParameter; + + const hsk = ks.ml_kem768_kp.secret_key.decaps(server_pub_key[0..hksl]) catch + return error.TlsDecryptFailure; + const xsk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[hksl..xksl].*) catch + return error.TlsDecryptFailure; + @memcpy(ks.sk_buf[0..hsk.len], &hsk); + @memcpy(ks.sk_buf[hsk.len..][0..xsk.len], &xsk); + ks.sk_len = hsk.len + xsk.len; + }, + .secp256r1 => { + const PublicKey = crypto.sign.ecdsa.EcdsaP256Sha256.PublicKey; + const pk = PublicKey.fromSec1(server_pub_key) catch return error.TlsDecryptFailure; + const mul = pk.p.mulPublic(ks.secp256r1_kp.secret_key.bytes, .big) catch + return error.TlsDecryptFailure; + const sk = mul.affineCoordinates().x.toBytes(.big); + @memcpy(ks.sk_buf[0..sk.len], &sk); + ks.sk_len = sk.len; + }, + .secp384r1 => { + const PublicKey = crypto.sign.ecdsa.EcdsaP384Sha384.PublicKey; + const pk = PublicKey.fromSec1(server_pub_key) catch return error.TlsDecryptFailure; + const mul = pk.p.mulPublic(ks.secp384r1_kp.secret_key.bytes, .big) catch + return error.TlsDecryptFailure; + const sk = mul.affineCoordinates().x.toBytes(.big); + @memcpy(ks.sk_buf[0..sk.len], &sk); + ks.sk_len = sk.len; + }, + .x25519 => { + const ksl = crypto.dh.X25519.public_length; + if (server_pub_key.len != ksl) return error.TlsIllegalParameter; + const sk = crypto.dh.X25519.scalarmult(ks.x25519_kp.secret_key, server_pub_key[0..ksl].*) catch + return error.TlsDecryptFailure; + @memcpy(ks.sk_buf[0..sk.len], &sk); + ks.sk_len = sk.len; + }, + else => return error.TlsIllegalParameter, + } + } + + fn getSharedSecret(ks: *const KeyShare) ?[]const u8 { + return if (ks.sk_len > 0) ks.sk_buf[0..ks.sk_len] else null; + } +}; + fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { return switch (scheme) { .ecdsa_secp256r1_sha256 => crypto.sign.ecdsa.EcdsaP256Sha256, @@ -1334,11 +1720,20 @@ fn SchemeEcdsa(comptime scheme: tls.SignatureScheme) type { }; } -fn SchemeHash(comptime scheme: tls.SignatureScheme) type { +fn SchemeRsa(comptime scheme: tls.SignatureScheme) type { return switch (scheme) { - .rsa_pss_rsae_sha256 => crypto.hash.sha2.Sha256, - .rsa_pss_rsae_sha384 => crypto.hash.sha2.Sha384, - .rsa_pss_rsae_sha512 => crypto.hash.sha2.Sha512, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + .rsa_pkcs1_sha1, + => Certificate.rsa.PKCS1v1_5Signature, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .rsa_pss_pss_sha256, + .rsa_pss_pss_sha384, + .rsa_pss_pss_sha512, + => Certificate.rsa.PSSSignature, else => @compileError("bad scheme"), }; } @@ -1350,6 +1745,146 @@ fn SchemeEddsa(comptime scheme: tls.SignatureScheme) type { }; } +fn SchemeHash(comptime scheme: tls.SignatureScheme) type { + return switch (scheme) { + .rsa_pkcs1_sha256, + .ecdsa_secp256r1_sha256, + .rsa_pss_rsae_sha256, + .rsa_pss_pss_sha256, + => crypto.hash.sha2.Sha256, + .rsa_pkcs1_sha384, + .ecdsa_secp384r1_sha384, + .rsa_pss_rsae_sha384, + .rsa_pss_pss_sha384, + => crypto.hash.sha2.Sha384, + .rsa_pkcs1_sha512, + .ecdsa_secp521r1_sha512, + .rsa_pss_rsae_sha512, + .rsa_pss_pss_sha512, + => crypto.hash.sha2.Sha512, + .rsa_pkcs1_sha1, + .ecdsa_sha1, + => crypto.hash.Sha1, + else => @compileError("bad scheme"), + }; +} + +const CertificatePublicKey = struct { + algo: Certificate.AlgorithmCategory, + buf: [600]u8, + len: u16, + + fn init( + cert_pub_key: *CertificatePublicKey, + algo: Certificate.AlgorithmCategory, + pub_key: []const u8, + ) error{CertificatePublicKeyInvalid}!void { + if (pub_key.len > cert_pub_key.buf.len) return error.CertificatePublicKeyInvalid; + cert_pub_key.algo = algo; + @memcpy(cert_pub_key.buf[0..pub_key.len], pub_key); + cert_pub_key.len = @intCast(pub_key.len); + } + + const VerifyError = error{ TlsDecodeError, TlsBadSignatureScheme, InvalidEncoding } || + // ecdsa + crypto.errors.EncodingError || + crypto.errors.NotSquareError || + crypto.errors.NonCanonicalError || + SchemeEcdsa(.ecdsa_secp256r1_sha256).Signature.VerifyError || + SchemeEcdsa(.ecdsa_secp384r1_sha384).Signature.VerifyError || + // rsa + error{TlsBadRsaSignatureBitCount} || + Certificate.rsa.PublicKey.ParseDerError || + Certificate.rsa.PublicKey.FromBytesError || + Certificate.rsa.PSSSignature.VerifyError || + Certificate.rsa.PKCS1v1_5Signature.VerifyError || + // eddsa + SchemeEddsa(.ed25519).Signature.VerifyError; + + fn verifySignature( + cert_pub_key: *const CertificatePublicKey, + sigd: *tls.Decoder, + msg: []const []const u8, + ) VerifyError!void { + const pub_key = cert_pub_key.buf[0..cert_pub_key.len]; + + try sigd.ensure(2 + 2); + const scheme = sigd.decode(tls.SignatureScheme); + const sig_len = sigd.decode(u16); + try sigd.ensure(sig_len); + const encoded_sig = sigd.slice(sig_len); + + if (cert_pub_key.algo != @as(Certificate.AlgorithmCategory, switch (scheme) { + .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => .X9_62_id_ecPublicKey, + .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .rsa_pkcs1_sha1, + => .rsaEncryption, + .rsa_pss_pss_sha256, + .rsa_pss_pss_sha384, + .rsa_pss_pss_sha512, + => .rsassa_pss, + else => return error.TlsBadSignatureScheme, + })) return error.TlsBadSignatureScheme; + + switch (scheme) { + inline .ecdsa_secp256r1_sha256, + .ecdsa_secp384r1_sha384, + => |comptime_scheme| { + const Ecdsa = SchemeEcdsa(comptime_scheme); + const sig = try Ecdsa.Signature.fromDer(encoded_sig); + const key = try Ecdsa.PublicKey.fromSec1(pub_key); + var ver = try sig.verifier(key); + for (msg) |part| ver.update(part); + try ver.verify(); + }, + inline .rsa_pkcs1_sha256, + .rsa_pkcs1_sha384, + .rsa_pkcs1_sha512, + .rsa_pss_rsae_sha256, + .rsa_pss_rsae_sha384, + .rsa_pss_rsae_sha512, + .rsa_pss_pss_sha256, + .rsa_pss_pss_sha384, + .rsa_pss_pss_sha512, + .rsa_pkcs1_sha1, + => |comptime_scheme| { + const RsaSignature = SchemeRsa(comptime_scheme); + const Hash = SchemeHash(comptime_scheme); + const PublicKey = Certificate.rsa.PublicKey; + const components = try PublicKey.parseDer(pub_key); + const exponent = components.exponent; + const modulus = components.modulus; + switch (modulus.len) { + inline 128, 256, 384, 512 => |modulus_len| { + const key: PublicKey = try .fromBytes(exponent, modulus); + const sig = RsaSignature.fromBytes(modulus_len, encoded_sig); + try RsaSignature.concatVerify(modulus_len, sig, msg, key, Hash); + }, + else => return error.TlsBadRsaSignatureBitCount, + } + }, + inline .ed25519 => |comptime_scheme| { + const Eddsa = SchemeEddsa(comptime_scheme); + if (encoded_sig.len != Eddsa.Signature.encoded_length) return error.InvalidEncoding; + const sig = Eddsa.Signature.fromBytes(encoded_sig[0..Eddsa.Signature.encoded_length].*); + if (pub_key.len != Eddsa.PublicKey.encoded_length) return error.InvalidEncoding; + const key = try Eddsa.PublicKey.fromBytes(pub_key[0..Eddsa.PublicKey.encoded_length].*); + var ver = try sig.verifier(key); + for (msg) |part| ver.update(part); + try ver.verify(); + }, + else => unreachable, + } + } +}; + /// Abstraction for sending multiple byte buffers to a slice of iovecs. const VecPut = struct { iovecs: []const std.posix.iovec, @@ -1447,20 +1982,26 @@ fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec { /// aes128-gcm: 138 MiB/s /// aes256-gcm: 120 MiB/s const cipher_suites = if (crypto.core.aes.has_hardware_support) - enum_array(tls.CipherSuite, &.{ + array(u16, tls.CipherSuite, .{ .AEGIS_128L_SHA256, .AEGIS_256_SHA512, .AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, .AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, .CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, }) else - enum_array(tls.CipherSuite, &.{ + array(u16, tls.CipherSuite, .{ .CHACHA20_POLY1305_SHA256, + .ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, .AEGIS_128L_SHA256, .AEGIS_256_SHA512, .AES_128_GCM_SHA256, + .ECDHE_RSA_WITH_AES_128_GCM_SHA256, .AES_256_GCM_SHA384, + .ECDHE_RSA_WITH_AES_256_GCM_SHA384, }); test { diff --git a/lib/std/http/Client.zig b/lib/std/http/Client.zig index 6e95995ee0..9dcf7b5693 100644 --- a/lib/std/http/Client.zig +++ b/lib/std/http/Client.zig @@ -388,6 +388,7 @@ pub const Connection = struct { // try to cleanly close the TLS connection, for any server that cares. _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {}; + if (conn.tls_client.ssl_key_log) |key_log| key_log.file.close(); allocator.destroy(conn.tls_client); } @@ -566,7 +567,7 @@ pub const Response = struct { .reason = undefined, .version = undefined, .keep_alive = false, - .parser = proto.HeadersParser.init(&header_buffer), + .parser = .init(&header_buffer), }; @memcpy(header_buffer[0..response_bytes.len], response_bytes); @@ -610,7 +611,7 @@ pub const Response = struct { } pub fn iterateHeaders(r: Response) http.HeaderIterator { - return http.HeaderIterator.init(r.parser.get()); + return .init(r.parser.get()); } test iterateHeaders { @@ -628,7 +629,7 @@ pub const Response = struct { .reason = undefined, .version = undefined, .keep_alive = false, - .parser = proto.HeadersParser.init(&header_buffer), + .parser = .init(&header_buffer), }; @memcpy(header_buffer[0..response_bytes.len], response_bytes); @@ -771,7 +772,7 @@ pub const Request = struct { req.client.connection_pool.release(req.client.allocator, req.connection.?); req.connection = null; - var server_header = std.heap.FixedBufferAllocator.init(req.response.parser.header_bytes_buffer); + var server_header: std.heap.FixedBufferAllocator = .init(req.response.parser.header_bytes_buffer); defer req.response.parser.header_bytes_buffer = server_header.buffer[server_header.end_index..]; const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); @@ -1354,7 +1355,27 @@ pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connec conn.data.tls_client = try client.allocator.create(std.crypto.tls.Client); errdefer client.allocator.destroy(conn.data.tls_client); - conn.data.tls_client.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed; + const ssl_key_log_file: ?std.fs.File = if (std.options.http_enable_ssl_key_log_file) ssl_key_log_file: { + const ssl_key_log_path = std.process.getEnvVarOwned(client.allocator, "SSLKEYLOGFILE") catch |err| switch (err) { + error.EnvironmentVariableNotFound, error.InvalidWtf8 => break :ssl_key_log_file null, + error.OutOfMemory => return error.OutOfMemory, + }; + defer client.allocator.free(ssl_key_log_path); + break :ssl_key_log_file std.fs.cwd().createFile(ssl_key_log_path, .{ + .truncate = false, + .mode = switch (builtin.os.tag) { + .windows, .wasi => 0, + else => 0o600, + }, + }) catch null; + } else null; + errdefer if (ssl_key_log_file) |key_log_file| key_log_file.close(); + + conn.data.tls_client.* = std.crypto.tls.Client.init(stream, .{ + .host = .{ .explicit = host }, + .ca = .{ .bundle = client.ca_bundle }, + .ssl_key_log_file = ssl_key_log_file, + }) catch return error.TlsInitializationFailed; // This is appropriate for HTTPS because the HTTP headers contain // the content length which is used to detect truncation attacks. conn.data.tls_client.allow_truncation_attacks = true; @@ -1620,7 +1641,7 @@ pub fn open( } } - var server_header = std.heap.FixedBufferAllocator.init(options.server_header_buffer); + var server_header: std.heap.FixedBufferAllocator = .init(options.server_header_buffer); const protocol, const valid_uri = try validateUri(uri, server_header.allocator()); if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .acquire)) { @@ -1654,7 +1675,7 @@ pub fn open( .status = undefined, .reason = undefined, .keep_alive = undefined, - .parser = proto.HeadersParser.init(server_header.buffer[server_header.end_index..]), + .parser = .init(server_header.buffer[server_header.end_index..]), }, .headers = options.headers, .extra_headers = options.extra_headers, diff --git a/lib/std/http/protocol.zig b/lib/std/http/protocol.zig index 78511f435d..c56d3a24a1 100644 --- a/lib/std/http/protocol.zig +++ b/lib/std/http/protocol.zig @@ -172,7 +172,13 @@ pub const HeadersParser = struct { const data_avail = r.next_chunk_length; if (skip) { - try conn.fill(); + conn.fill() catch |err| switch (err) { + error.EndOfStream => { + r.done = true; + return 0; + }, + else => |e| return e, + }; const nread = @min(conn.peek().len, data_avail); conn.drop(@intCast(nread)); @@ -196,7 +202,13 @@ pub const HeadersParser = struct { } }, .chunk_data_suffix, .chunk_data_suffix_r, .chunk_head_size, .chunk_head_ext, .chunk_head_r => { - try conn.fill(); + conn.fill() catch |err| switch (err) { + error.EndOfStream => { + r.done = true; + return 0; + }, + else => |e| return e, + }; const i = r.findChunkedLen(conn.peek()); conn.drop(@intCast(i)); @@ -226,7 +238,13 @@ pub const HeadersParser = struct { const out_avail = buffer.len - out_index; if (skip) { - try conn.fill(); + conn.fill() catch |err| switch (err) { + error.EndOfStream => { + r.done = true; + return 0; + }, + else => |e| return e, + }; const nread = @min(conn.peek().len, data_avail); conn.drop(@intCast(nread)); diff --git a/lib/std/std.zig b/lib/std/std.zig index 6dbb4c0843..cc61111746 100644 --- a/lib/std/std.zig +++ b/lib/std/std.zig @@ -146,6 +146,11 @@ pub const Options = struct { /// make a HTTPS connection. http_disable_tls: bool = false, + /// This enables `std.http.Client` to log ssl secrets to the file specified by the SSLKEYLOGFILE + /// env var. Creating such a log file allows other programs with access to that file to decrypt + /// all `std.http.Client` traffic made by this program. + http_enable_ssl_key_log_file: bool = @import("builtin").mode == .Debug, + side_channels_mitigations: crypto.SideChannelsMitigations = crypto.default_side_channels_mitigations, };