const std = @import("../std.zig");
const testing = std.testing;

/// Performs multiple async functions in parallel, without heap allocation.
/// Async function frames are managed externally to this abstraction, and
/// passed in via the `add` function. Once all the jobs are added, call `wait`.
/// This API is *not* thread-safe. The object must be accessed from one thread at
/// a time, however, it need not be the same thread.
pub fn Batch(
    /// The return value for each job.
    /// If a job slot was re-used due to maxed out concurrency, then its result
    /// value will be overwritten. The values can be accessed through the
    /// `result` member of each entry of the `jobs` field.
    comptime Result: type,
    /// How many jobs to run in parallel.
    comptime max_jobs: comptime_int,
    /// Controls whether the `add` and `wait` functions will be async functions.
    comptime async_behavior: enum {
        /// Observe the value of `std.io.is_async` to decide whether `add`
        /// and `wait` will be async functions. Asserts that the jobs do not suspend when
        /// `std.options.io_mode == .blocking`. This is a generally safe assumption, and the
        /// usual recommended option for this parameter.
        auto_async,

        /// Always uses the `nosuspend` keyword when using `await` on the jobs,
        /// making `add` and `wait` non-async functions. Asserts that the jobs do not suspend.
        never_async,

        /// `add` and `wait` use regular `await` keyword, making them async functions.
        always_async,
    },
) type {
    return struct {
        /// Fixed pool of job slots, re-used round-robin by `add`.
        jobs: [max_jobs]Job,
        /// Index of the slot that the next call to `add` will use.
        next_job_index: usize,
        /// Most recent error observed while awaiting a job, if any.
        /// It is only ever written in `init` and when a job fails, so once an
        /// error has been recorded it stays recorded.
        collected_result: CollectedResult,

        const Job = struct {
            /// The in-flight frame occupying this slot, or `null` when the slot is free.
            frame: ?anyframe->Result,
            /// Result of the most recently awaited frame in this slot.
            /// Undefined until a frame in this slot has been awaited.
            result: Result,
        };

        const Self = @This();

        /// What `wait` returns: `void` when `Result` cannot fail, otherwise an
        /// error union with the same error set as `Result` and a `void` payload.
        /// Only the error is collected; payload values stay in the job slots.
        const CollectedResult = switch (@typeInfo(Result)) {
            .ErrorUnion => |info| info.error_set!void,
            else => void,
        };

        const async_ok = switch (async_behavior) {
            .auto_async => std.io.is_async,
            .never_async => false,
            .always_async => true,
        };

        /// Creates an empty batch: every slot is free and no error has been collected.
        pub fn init() Self {
            return Self{
                .jobs = [1]Job{
                    .{
                        .frame = null,
                        .result = undefined,
                    },
                } ** max_jobs,
                .next_job_index = 0,
                .collected_result = {},
            };
        }

        /// Add a frame to the Batch. If all jobs are in-flight, then this function
        /// waits until one completes.
        /// This function is *not* thread-safe. It must be called from one thread at
        /// a time, however, it need not be the same thread.
        /// TODO: "select" language feature to use the next available slot, rather than
        /// awaiting the next index.
        pub fn add(self: *Self, frame: anyframe->Result) void {
            const job = &self.jobs[self.next_job_index];
            self.next_job_index = (self.next_job_index + 1) % max_jobs;
            if (job.frame) |existing| {
                // The slot is occupied: finish its current job before re-using it.
                job.result = if (async_ok) await existing else nosuspend await existing;
                self.collectError(job.result);
            }
            job.frame = frame;
        }

        /// Wait for all the jobs to complete.
        /// Safe to call any number of times.
        /// If `Result` is an error union, this function returns the last error that occurred, if any.
        /// Unlike the per-slot `result` values in `jobs`, the return value of `wait` will report
        /// any error that occurred; hitting max parallelism will not compromise the result.
        /// This function is *not* thread-safe. It must be called from one thread at
        /// a time, however, it need not be the same thread.
        pub fn wait(self: *Self) CollectedResult {
            // Iterate over a pointer to the array so that `job` refers to the
            // stored slot rather than to a copy.
            for (&self.jobs) |*job| {
                if (job.frame) |f| {
                    job.result = if (async_ok) await f else nosuspend await f;
                    self.collectError(job.result);
                    // Mark the slot free so that a repeated `wait` does not await it again.
                    job.frame = null;
                }
            }
            return self.collected_result;
        }

        /// Records the error carried by `result`, if any, in `collected_result`.
        /// Does nothing when `Result` is not an error union.
        fn collectError(self: *Self, result: Result) void {
            if (CollectedResult != void) {
                // `if`/`else |err|` is used instead of `catch` so that this also
                // type-checks when the payload of `Result` is not `void`.
                if (result) |_| {} else |err| {
                    self.collected_result = err;
                }
            }
        }
    };
}

test "std.event.Batch" {
    // This test is currently skipped unconditionally.
    if (true) return error.SkipZigTest;
    var count: usize = 0;
    var batch = Batch(void, 2, .auto_async).init();
    batch.add(&async sleepALittle(&count));
    batch.add(&async increaseByTen(&count));
    batch.wait();
    try testing.expect(count == 11);

    var another = Batch(anyerror!void, 2, .auto_async).init();
    another.add(&async somethingElse());
    another.add(&async doSomethingThatFails());
    try testing.expectError(error.ItBroke, another.wait());
}

/// Sleeps for one millisecond, then atomically increments `count` by one.
fn sleepALittle(count: *usize) void {
    std.time.sleep(1 * std.time.ns_per_ms);
    _ = @atomicRmw(usize, count, .Add, 1, .SeqCst);
}

/// Atomically increments `count` ten times, one step at a time.
fn increaseByTen(count: *usize) void {
    var i: usize = 0;
    while (i < 10) : (i += 1) {
        _ = @atomicRmw(usize, count, .Add, 1, .SeqCst);
    }
}

/// Always fails with `error.ItBroke`.
fn doSomethingThatFails() anyerror!void {
    return error.ItBroke;
}

/// Always succeeds.
fn somethingElse() anyerror!void {}