From 57313b789e870e8c385aa240e2d93d4a50550e65 Mon Sep 17 00:00:00 2001 From: Gabriel Uehlein Date: Sat, 23 Nov 2024 13:54:57 -0500 Subject: [PATCH 1/5] Make @enumFromInt do range checks on non-exhaustive enums' tag types Fixes #21946 --- src/Sema.zig | 32 ++++++++++++++++--- ..._int_narrows_to_tag_type_nonexhaustive.zig | 27 ++++++++++++++++ ...valid_nonexhaustive_enum_integer_value.zig | 19 +++++++++++ 3 files changed, 73 insertions(+), 5 deletions(-) create mode 100644 test/cases/enum_from_int_narrows_to_tag_type_nonexhaustive.zig create mode 100644 test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig diff --git a/src/Sema.zig b/src/Sema.zig index 0424970f18..26e3ce535e 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -9093,12 +9093,34 @@ fn zirEnumFromInt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError try sema.requireRuntimeBlock(block, src, operand_src); const result = try block.addTyOp(.intcast, dest_ty, operand); - if (block.wantSafety() and !dest_ty.isNonexhaustiveEnum(zcu) and - zcu.backendSupportsFeature(.is_named_enum_value)) - { - const ok = try block.addUnOp(.is_named_enum_value, result); - try sema.addSafetyCheck(block, src, ok, .invalid_enum_value); + if (block.wantSafety()) { + if (dest_ty.isNonexhaustiveEnum(zcu)) { + const dest_int = dest_ty.intTagType(zcu); + const operand_ty = sema.typeOf(operand); + const dest_int_info = dest_int.intInfo(zcu); + if (operand_ty.toIntern() == .comptime_int_type or operand_ty.intInfo(zcu).bits > dest_int_info.bits or operand_ty.intInfo(zcu).signedness != dest_int_info.signedness) { + const max_int_val = try dest_int.maxIntScalar(pt, operand_ty); + const max_int = Air.internedToRef(max_int_val.toIntern()); + // operand <= max_int + const le_check = try block.addBinOp(.cmp_lte, operand, max_int); + if (operand_ty.intInfo(zcu).signedness == .signed) { + const min_int_val = try dest_int.minIntScalar(pt, operand_ty); + const min_int = Air.internedToRef(min_int_val.toIntern()); + // operand >= min_int + const ge_check = try block.addBinOp(.cmp_gte, operand, min_int); + // min_int <= operand <= max_int + const in_range_check = try block.addBinOp(.bool_and, le_check, ge_check); + try sema.addSafetyCheck(block, src, in_range_check, .invalid_enum_value); + } else { + try sema.addSafetyCheck(block, src, le_check, .invalid_enum_value); + } + } + } else if (zcu.backendSupportsFeature(.is_named_enum_value)) { + const ok = try block.addUnOp(.is_named_enum_value, result); + try sema.addSafetyCheck(block, src, ok, .invalid_enum_value); + } } + return result; } diff --git a/test/cases/enum_from_int_narrows_to_tag_type_nonexhaustive.zig b/test/cases/enum_from_int_narrows_to_tag_type_nonexhaustive.zig new file mode 100644 index 0000000000..737896f282 --- /dev/null +++ b/test/cases/enum_from_int_narrows_to_tag_type_nonexhaustive.zig @@ -0,0 +1,27 @@ +const std = @import("std"); + +const SignedWithVariants = enum(i4) { a, b, _ }; + +const UnsignedWithVariants = enum(u4) { a, b, _ }; + +const SignedEmpty = enum(i6) { _ }; + +const UnsignedEmpty = enum(u6) { _ }; + +pub fn main() void { + inline for (.{ SignedWithVariants, UnsignedWithVariants, SignedEmpty, UnsignedEmpty }) |EnumTy| { + const TagType = @typeInfo(EnumTy).@"enum".tag_type; + var v: isize = std.math.minInt(TagType); + while (v < std.math.maxInt(TagType)) : (v += 1) { + const variant = @as(EnumTy, @enumFromInt(v)); + assert(@as(@TypeOf(v), @intCast(@intFromEnum(variant))) == v); + } + } +} + +fn assert(ok: bool) void { + if (!ok) unreachable; +} + +// run +// backend=stage2,llvm diff --git a/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig b/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig new file mode 100644 index 0000000000..4597c237b2 --- /dev/null +++ b/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig @@ -0,0 +1,19 @@ +const std = @import("std"); + +pub fn panic(message: []const u8, _: ?*std.builtin.StackTrace, _: ?usize) noreturn { + if (std.mem.eql(u8, message, "invalid enum value")) { + std.process.exit(0); + } + std.process.exit(1); +} + +pub fn main() void { + const E = enum(u4) { _ }; + var invalid: u16 = 16; + _ = &invalid; + std.mem.doNotOptimizeAway(@as(E, @enumFromInt(invalid))); + std.process.exit(1); +} + +// run +// backend=stage2,llvm From adabfa2c9df04273bf9d27427b6e56cfd9e6f6a5 Mon Sep 17 00:00:00 2001 From: Gabriel Uehlein Date: Sat, 23 Nov 2024 14:01:39 -0500 Subject: [PATCH 2/5] Remove comptime_int check in zirEnumFromInt --- src/Sema.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Sema.zig b/src/Sema.zig index 26e3ce535e..40ca461a53 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -9098,7 +9098,7 @@ fn zirEnumFromInt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const dest_int = dest_ty.intTagType(zcu); const operand_ty = sema.typeOf(operand); const dest_int_info = dest_int.intInfo(zcu); - if (operand_ty.toIntern() == .comptime_int_type or operand_ty.intInfo(zcu).bits > dest_int_info.bits or operand_ty.intInfo(zcu).signedness != dest_int_info.signedness) { + if (operand_ty.intInfo(zcu).bits > dest_int_info.bits or operand_ty.intInfo(zcu).signedness != dest_int_info.signedness) { const max_int_val = try dest_int.maxIntScalar(pt, operand_ty); const max_int = Air.internedToRef(max_int_val.toIntern()); // operand <= max_int From 44c1e35c14feea9089987d8662e8e4012f94aeb3 Mon Sep 17 00:00:00 2001 From: Gabriel Uehlein Date: Sun, 24 Nov 2024 14:47:29 -0500 Subject: [PATCH 3/5] Safer (in the compiler) checks + don't use doNotOptimizeAway in safety test --- src/Sema.zig | 58 +++++++++++++------ ..._int_narrows_to_tag_type_nonexhaustive.zig | 3 + ...valid_nonexhaustive_enum_integer_value.zig | 3 +- 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/src/Sema.zig b/src/Sema.zig index 40ca461a53..ceaff910ba 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -9095,32 +9095,56 @@ fn zirEnumFromInt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const result = try block.addTyOp(.intcast, dest_ty, operand); if (block.wantSafety()) { if (dest_ty.isNonexhaustiveEnum(zcu)) { - const dest_int = dest_ty.intTagType(zcu); + // these checks are similar to those generated by `intCast`, + // except that `invalid_enum_value` is the panic reason for all + // failures. const operand_ty = sema.typeOf(operand); - const dest_int_info = dest_int.intInfo(zcu); - if (operand_ty.intInfo(zcu).bits > dest_int_info.bits or operand_ty.intInfo(zcu).signedness != dest_int_info.signedness) { - const max_int_val = try dest_int.maxIntScalar(pt, operand_ty); - const max_int = Air.internedToRef(max_int_val.toIntern()); - // operand <= max_int - const le_check = try block.addBinOp(.cmp_lte, operand, max_int); - if (operand_ty.intInfo(zcu).signedness == .signed) { - const min_int_val = try dest_int.minIntScalar(pt, operand_ty); - const min_int = Air.internedToRef(min_int_val.toIntern()); - // operand >= min_int - const ge_check = try block.addBinOp(.cmp_gte, operand, min_int); - // min_int <= operand <= max_int - const in_range_check = try block.addBinOp(.bool_and, le_check, ge_check); - try sema.addSafetyCheck(block, src, in_range_check, .invalid_enum_value); + const operand_int_info = sema.typeOf(operand).intInfo(zcu); + const dest_tag_ty = dest_ty.intTagType(zcu); + const dest_int_info = dest_tag_ty.intInfo(zcu); + var ok: ?Air.Inst.Ref = null; + if (operand_int_info.bits > dest_int_info.bits) { + // narrowing cast; operand <= dest_max_int required, and operand >= 0 may be + // required if operand is signed and destination int is unsigned + const dest_max_int = Air.internedToRef((try dest_tag_ty.maxIntScalar(pt, operand_ty)).toIntern()); + // operand <= maxInt(dest_tag_ty) + const le_check = try block.addBinOp(.cmp_lte, operand, dest_max_int); + if (operand_int_info.signedness == .signed) { + const dest_min_int = Air.internedToRef((try dest_tag_ty.minIntScalar(pt, operand_ty)).toIntern()); + const ge_zero_check = try block.addBinOp(.cmp_gte, operand, dest_min_int); + // operand >= minInt(dest_tag_ty) and operand <= maxInt(dest_tag_ty) + ok = try block.addBinOp(.bool_and, le_check, ge_zero_check); } else { - try sema.addSafetyCheck(block, src, le_check, .invalid_enum_value); + ok = le_check; + } + } else if (operand_int_info.bits == dest_int_info.bits) { + // checks are only needed here if the operand type's sign differs + // from the destination type's tag's sign. + if (operand_int_info.signedness == .unsigned and dest_int_info.signedness == .signed) { + const dest_max_int = try dest_tag_ty.maxIntScalar(pt, operand_ty); + // operand <= maxInt(dest_tag_ty) + ok = try block.addBinOp(.cmp_lte, operand, Air.internedToRef(dest_max_int.toIntern())); + } else if (operand_int_info.signedness == .signed and dest_int_info.signedness == .unsigned) { + const zero = try pt.intValue_i64(operand_ty, 0); + // operand >= 0 + ok = try block.addBinOp(.cmp_gte, operand, Air.internedToRef(zero.toIntern())); + } // else => operand and destination int are the same type; no checks needed + } else { + // extending cast; no checks needed unless the operand is signed + // and the destination is unsigned + if (operand_int_info.signedness == .signed and dest_int_info.signedness == .unsigned) { + const zero = try pt.intValue_i64(operand_ty, 0); + // operand >= 0 + ok = try block.addBinOp(.cmp_gte, operand, Air.internedToRef(zero.toIntern())); } } + if (ok) |check| + try sema.addSafetyCheck(block, src, check, .invalid_enum_value); } else if (zcu.backendSupportsFeature(.is_named_enum_value)) { const ok = try block.addUnOp(.is_named_enum_value, result); try sema.addSafetyCheck(block, src, ok, .invalid_enum_value); } } - return result; } diff --git a/test/cases/enum_from_int_narrows_to_tag_type_nonexhaustive.zig b/test/cases/enum_from_int_narrows_to_tag_type_nonexhaustive.zig index 737896f282..c404a8e282 100644 --- a/test/cases/enum_from_int_narrows_to_tag_type_nonexhaustive.zig +++ b/test/cases/enum_from_int_narrows_to_tag_type_nonexhaustive.zig @@ -16,6 +16,9 @@ pub fn main() void { const variant = @as(EnumTy, @enumFromInt(v)); assert(@as(@TypeOf(v), @intCast(@intFromEnum(variant))) == v); } + const max = std.math.maxInt(TagType); + const max_variant = @as(EnumTy, @enumFromInt(max)); + assert(@as(@TypeOf(max), @intCast(@intFromEnum(max_variant))) == max); } } diff --git a/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig b/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig index 4597c237b2..f0b7c1e829 100644 --- a/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig +++ b/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig @@ -8,10 +8,11 @@ pub fn panic(message: []const u8, _: ?*std.builtin.StackTrace, _: ?usize) noretu } pub fn main() void { + @setRuntimeSafety(true); const E = enum(u4) { _ }; var invalid: u16 = 16; _ = &invalid; - std.mem.doNotOptimizeAway(@as(E, @enumFromInt(invalid))); + _ = @as(E, @enumFromInt(invalid)); std.process.exit(1); } From a5d6c6e16cb28f5003f2619e93e05ae4eda18764 Mon Sep 17 00:00:00 2001 From: Gabriel Uehlein Date: Sun, 24 Nov 2024 23:10:36 -0500 Subject: [PATCH 4/5] Add target=native to invalid_nonexhaustive_enum_integer_value.zig Selfhosted WASM turns panic into unreachable (I think?) so this commit disables testing stage2 --- test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig b/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig index f0b7c1e829..ae0c4c1771 100644 --- a/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig +++ b/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig @@ -8,7 +8,6 @@ pub fn panic(message: []const u8, _: ?*std.builtin.StackTrace, _: ?usize) noretu } pub fn main() void { - @setRuntimeSafety(true); const E = enum(u4) { _ }; var invalid: u16 = 16; _ = &invalid; @@ -18,3 +17,4 @@ pub fn main() void { // run // backend=stage2,llvm +// target=native From 990588eb76fab3950c13b47e4dd14baf08d520b8 Mon Sep 17 00:00:00 2001 From: Gabriel Uehlein Date: Sun, 24 Nov 2024 23:20:11 -0500 Subject: [PATCH 5/5] Remove stage2 testing of invalid_nonexhaustive_enum_integer_value.zig This makes this test like the other safety tests, as they don't test stage2 --- test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig b/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig index ae0c4c1771..4c14ce2869 100644 --- a/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig +++ b/test/cases/safety/invalid_nonexhaustive_enum_integer_value.zig @@ -16,5 +16,5 @@ pub fn main() void { } // run -// backend=stage2,llvm +// backend=llvm // target=native