mirror of
https://github.com/ziglang/zig.git
synced 2024-11-27 15:42:49 +00:00
4a69b11e74
add SPDX license identifier copyright ownership is zig contributors
145 lines
5.4 KiB
Zig
145 lines
5.4 KiB
Zig
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2015-2020 Zig Contributors
|
|
// This file is part of [zig](https://ziglang.org/), which is MIT licensed.
|
|
// The MIT license requires this copyright notice to be included in all copies
|
|
// and substantial portions of the software.
|
|
const std = @import("../std.zig");
|
|
const testing = std.testing;
|
|
|
|
/// Performs multiple async functions in parallel, without heap allocation.
/// Async function frames are managed externally to this abstraction, and
/// passed in via the `add` function. Once all the jobs are added, call `wait`.
/// This API is *not* thread-safe. The object must be accessed from one thread at
/// a time, however, it need not be the same thread.
pub fn Batch(
    /// The return value for each job.
    /// If a job slot was re-used due to maxed out concurrency, then its result
    /// value will be overwritten. The latest value stored for each slot can be
    /// read from `jobs[i].result` once that slot has been awaited.
    comptime Result: type,
    /// How many jobs to run in parallel. Must be at least 1.
    comptime max_jobs: comptime_int,
    /// Controls whether the `add` and `wait` functions will be async functions.
    comptime async_behavior: enum {
        /// Observe the value of `std.io.is_async` to decide whether `add`
        /// and `wait` will be async functions. Asserts that the jobs do not suspend when
        /// `std.io.mode == .blocking`. This is a generally safe assumption, and the
        /// usual recommended option for this parameter.
        auto_async,

        /// Always uses the `nosuspend` keyword when using `await` on the jobs,
        /// making `add` and `wait` non-async functions. Asserts that the jobs do not suspend.
        never_async,

        /// `add` and `wait` use regular `await` keyword, making them async functions.
        always_async,
    },
) type {
    return struct {
        /// One slot per concurrently running job; slots are reused round-robin.
        jobs: [max_jobs]Job,
        /// Slot that the next call to `add` will occupy.
        next_job_index: usize,
        /// Most recently observed error (or success) across all awaited jobs.
        collected_result: CollectedResult,

        const Job = struct {
            frame: ?anyframe->Result,
            result: Result,
        };

        const Self = @This();

        comptime {
            // `add` advances the slot index modulo `max_jobs`; zero slots cannot work.
            if (max_jobs < 1) @compileError("Batch requires max_jobs >= 1");
        }

        /// `E!void` when `Result` is `E!T`, otherwise `void`.
        /// Only the error (if any) is collected, so the payload is dropped; using
        /// `void` as the payload lets `init` start in the success state for any `T`.
        const CollectedResult = switch (@typeInfo(Result)) {
            .ErrorUnion => |info| info.error_set!void,
            else => void,
        };

        /// Whether jobs are awaited with a plain `await` (true) or `nosuspend await` (false).
        const async_ok = switch (async_behavior) {
            .auto_async => std.io.is_async,
            .never_async => false,
            .always_async => true,
        };

        /// Returns a batch with every slot empty and no error collected.
        pub fn init() Self {
            return Self{
                .jobs = [1]Job{
                    .{
                        .frame = null,
                        .result = undefined,
                    },
                } ** max_jobs,
                .next_job_index = 0,
                .collected_result = {},
            };
        }

        /// Add a frame to the Batch. If all jobs are in-flight, then this function
        /// waits until one completes.
        /// This function is *not* thread-safe. It must be called from one thread at
        /// a time, however, it need not be the same thread.
        /// TODO: "select" language feature to use the next available slot, rather than
        /// awaiting the next index.
        pub fn add(self: *Self, frame: anyframe->Result) void {
            const job = &self.jobs[self.next_job_index];
            self.next_job_index = (self.next_job_index + 1) % max_jobs;
            // The slot is still occupied by an earlier job: finish it before reuse.
            if (job.frame) |existing| {
                self.finishJob(job, existing);
            }
            job.frame = frame;
        }

        /// Wait for all the jobs to complete.
        /// Safe to call any number of times.
        /// If `Result` is an error union, this function returns the last error that
        /// occurred (in await order), if any. The collected error is never cleared, so
        /// later calls keep returning it.
        /// Unlike the per-slot `jobs[i].result` values, the return value of `wait` will
        /// report any error that occurred; hitting max parallelism will not compromise
        /// the result.
        /// This function is *not* thread-safe. It must be called from one thread at
        /// a time, however, it need not be the same thread.
        pub fn wait(self: *Self) CollectedResult {
            // Slots are awaited in index order, not submission order.
            for (self.jobs) |*job| {
                if (job.frame) |frame| {
                    self.finishJob(job, frame);
                    job.frame = null;
                }
            }
            return self.collected_result;
        }

        /// Await `frame`, store its value in `job.result`, and when `Result` is an
        /// error union record any error in `collected_result`.
        fn finishJob(self: *Self, job: *Job, frame: anyframe->Result) void {
            job.result = if (async_ok) await frame else nosuspend await frame;
            // Comptime-known condition: the body is only analyzed for error unions.
            if (CollectedResult != void) {
                // Capture form works for any payload type, unlike `catch` with a void block.
                if (job.result) |_| {} else |err| {
                    self.collected_result = err;
                }
            }
        }
    };
}
|
|
|
|
test "std.event.Batch" {
    var count: usize = 0;
    var batch = Batch(void, 2, .auto_async).init();
    batch.add(&async sleepALittle(&count));
    batch.add(&async increaseByTen(&count));
    batch.wait();
    testing.expect(count == 11);

    // More jobs than slots: the third `add` reuses slot 0, so it must await the
    // job already occupying that slot before storing the new frame.
    var reuse_count: usize = 0;
    var reuse = Batch(void, 2, .auto_async).init();
    reuse.add(&async increaseByTen(&reuse_count));
    reuse.add(&async increaseByTen(&reuse_count));
    reuse.add(&async increaseByTen(&reuse_count));
    reuse.wait();
    testing.expect(reuse_count == 30);

    var another = Batch(anyerror!void, 2, .auto_async).init();
    another.add(&async somethingElse());
    another.add(&async doSomethingThatFails());
    testing.expectError(error.ItBroke, another.wait());
    // `wait` is safe to call repeatedly, and the collected error is never cleared.
    testing.expectError(error.ItBroke, another.wait());
}
|
|
|
|
/// Test helper: blocks for about one millisecond, then atomically bumps `count` by one.
fn sleepALittle(count: *usize) void {
    const delay_ns = 1 * std.time.ns_per_ms;
    std.time.sleep(delay_ns);
    const previous = @atomicRmw(usize, count, .Add, 1, .SeqCst);
    _ = previous;
}
|
|
|
|
/// Test helper: adds ten to `count`, one atomic increment at a time.
fn increaseByTen(count: *usize) void {
    var remaining: usize = 10;
    while (remaining != 0) : (remaining -= 1) {
        _ = @atomicRmw(usize, count, .Add, 1, .SeqCst);
    }
}
|
|
|
|
/// Test helper that, as its name says, always fails with `error.ItBroke`.
fn doSomethingThatFails() anyerror!void {
    return error.ItBroke;
}

/// Test helper that always completes successfully.
fn somethingElse() anyerror!void {}
|