diff --git a/src/all_types.hpp b/src/all_types.hpp index ef159986a1..afe8bd0675 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1680,6 +1680,7 @@ enum PanicMsgId { PanicMsgIdResumedAnAwaitingFn, PanicMsgIdFrameTooSmall, PanicMsgIdResumedFnPendingAwait, + PanicMsgIdBadNoAsyncCall, PanicMsgIdCount, }; diff --git a/src/codegen.cpp b/src/codegen.cpp index 03c253ad48..6c03be32c3 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -923,6 +923,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) { return buf_create_from_str("frame too small"); case PanicMsgIdResumedFnPendingAwait: return buf_create_from_str("resumed an async function which can only be awaited"); + case PanicMsgIdBadNoAsyncCall: + return buf_create_from_str("async function called with noasync suspended"); } zig_unreachable(); } @@ -4067,6 +4069,25 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr } else if (instruction->modifier == CallModifierNoAsync && !fn_is_async(g->cur_fn)) { gen_resume(g, fn_val, frame_result_loc, ResumeIdCall); + if (ir_want_runtime_safety(g, &instruction->base)) { + LLVMValueRef awaiter_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, + frame_awaiter_index, ""); + LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref); + LLVMValueRef prev_val = gen_maybe_atomic_op(g, LLVMAtomicRMWBinOpXchg, awaiter_ptr, + all_ones, LLVMAtomicOrderingRelease); + LLVMValueRef ok_val = LLVMBuildICmp(g->builder, LLVMIntEQ, prev_val, all_ones, ""); + + LLVMBasicBlockRef bad_block = LLVMAppendBasicBlock(g->cur_fn_val, "NoAsyncPanic"); + LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "NoAsyncOk"); + LLVMBuildCondBr(g->builder, ok_val, ok_block, bad_block); + + // The async function suspended, but this noasync call asserted it wouldn't. + LLVMPositionBuilderAtEnd(g->builder, bad_block); + gen_safety_crash(g, PanicMsgIdBadNoAsyncCall); + + LLVMPositionBuilderAtEnd(g->builder, ok_block); + } + ZigType *result_type = instruction->base.value.type; ZigType *ptr_result_type = get_pointer_to_type(g, result_type, true); return gen_await_early_return(g, &instruction->base, frame_result_loc, diff --git a/test/runtime_safety.zig b/test/runtime_safety.zig index 07a8c3910a..17f0f3230c 100644 --- a/test/runtime_safety.zig +++ b/test/runtime_safety.zig @@ -1,6 +1,21 @@ const tests = @import("tests.zig"); pub fn addCases(cases: *tests.CompareOutputContext) void { + cases.addRuntimeSafety("noasync function call, callee suspends", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\pub fn main() void { + \\ _ = noasync add(101, 100); + \\} + \\fn add(a: i32, b: i32) i32 { + \\ if (a > 100) { + \\ suspend; + \\ } + \\ return a + b; + \\} + ); + cases.addRuntimeSafety("awaiting twice", \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { \\ @import("std").os.exit(126);