From d346d074ebe5347f730a70d3a88b12f279bb405d Mon Sep 17 00:00:00 2001 From: ippsav <69125922+ippsav@users.noreply.github.com> Date: Mon, 11 Nov 2024 22:34:24 +0100 Subject: [PATCH] Enable thread_pool function to throw errors (#20260) * std.ThreadPool: allow error union return type * allow noreturn in Pool.zig --- lib/std/Thread/Pool.zig | 44 +++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/lib/std/Thread/Pool.zig b/lib/std/Thread/Pool.zig index 86bac7ce46..4dd7513373 100644 --- a/lib/std/Thread/Pool.zig +++ b/lib/std/Thread/Pool.zig @@ -97,7 +97,7 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args wait_group.start(); if (builtin.single_threaded) { - @call(.auto, func, args); + callFn(func, args); wait_group.finish(); return; } @@ -112,7 +112,7 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args fn runFn(runnable: *Runnable, _: ?usize) void { const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable); const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node)); - @call(.auto, func, closure.arguments); + callFn(func, closure.arguments); closure.wait_group.finish(); // The thread pool's allocator is protected by the mutex. @@ -129,7 +129,7 @@ pub fn spawnWg(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, args const closure = pool.allocator.create(Closure) catch { pool.mutex.unlock(); - @call(.auto, func, args); + callFn(func, args); wait_group.finish(); return; }; @@ -160,7 +160,7 @@ pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, ar wait_group.start(); if (builtin.single_threaded) { - @call(.auto, func, .{0} ++ args); + callFn(func, .{0} ++ args); wait_group.finish(); return; } @@ -175,7 +175,7 @@ pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, ar fn runFn(runnable: *Runnable, id: ?usize) void { const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable); const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node)); - @call(.auto, func, .{id.?} ++ closure.arguments); + callFn(func, .{id.?} ++ closure.arguments); closure.wait_group.finish(); // The thread pool's allocator is protected by the mutex. @@ -193,7 +193,7 @@ pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, ar const closure = pool.allocator.create(Closure) catch { const id: ?usize = pool.ids.getIndex(std.Thread.getCurrentId()); pool.mutex.unlock(); - @call(.auto, func, .{id.?} ++ args); + callFn(func, .{id.?} ++ args); wait_group.finish(); return; }; @@ -213,7 +213,7 @@ pub fn spawnWgId(pool: *Pool, wait_group: *WaitGroup, comptime func: anytype, ar pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void { if (builtin.single_threaded) { - @call(.auto, func, args); + callFn(func, args); return; } @@ -226,7 +226,7 @@ pub fn spawn(pool: *Pool, comptime func: anytype, args: anytype) !void { fn runFn(runnable: *Runnable, _: ?usize) void { const run_node: *RunQueue.Node = @fieldParentPtr("data", runnable); const closure: *@This() = @alignCast(@fieldParentPtr("run_node", run_node)); - @call(.auto, func, closure.arguments); + callFn(func, closure.arguments); // The thread pool's allocator is protected by the mutex. const mutex = &closure.pool.mutex; @@ -321,3 +321,31 @@ pub fn waitAndWork(pool: *Pool, wait_group: *WaitGroup) void { pub fn getIdCount(pool: *Pool) usize { return @intCast(1 + pool.threads.len); } + +inline fn callFn(comptime f: anytype, args: anytype) void { + const bad_fn_ret = "expected return type of runFn to be 'void', '!void', noreturn, or !noreturn"; + + switch (@typeInfo(@typeInfo(@TypeOf(f)).@"fn".return_type.?)) { + .void, .noreturn => { + @call(.auto, f, args); + }, + .error_union => |info| { + switch (info.payload) { + void, noreturn => { + @call(.auto, f, args) catch |err| { + std.debug.print("error: {s}\n", .{@errorName(err)}); + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); + } + }; + }, + else => { + @compileError(bad_fn_ret); + }, + } + }, + else => { + @compileError(bad_fn_ret); + }, + } +}