diff --git a/example/guess_number/main.zig b/example/guess_number/main.zig index 4bba33d05f..1d111aaf08 100644 --- a/example/guess_number/main.zig +++ b/example/guess_number/main.zig @@ -12,7 +12,7 @@ pub fn main() -> %void { const seed = std.mem.readInt(seed_bytes, usize, true); var rand = Rand.init(seed); - const answer = rand.rangeUnsigned(u8, 0, 100) + 1; + const answer = rand.range(u8, 0, 100) + 1; while (true) { %%io.stdout.printf("\nGuess a number between 1 and 100: "); diff --git a/std/math.zig b/std/math.zig index 0dae77552f..a42266090c 100644 --- a/std/math.zig +++ b/std/math.zig @@ -41,6 +41,10 @@ pub fn sub(comptime T: type, a: T, b: T) -> %T { if (@subWithOverflow(T, a, b, &answer)) error.Overflow else answer } +pub fn negate(x: var) -> %@typeOf(x) { + return sub(@typeOf(x), 0, x); +} + error Overflow; pub fn shl(comptime T: type, a: T, b: T) -> %T { var answer: T = undefined; @@ -341,3 +345,51 @@ test "math.floor" { assert(floor(f64(999.0)) == 999.0); assert(floor(f64(-999.0)) == -999.0); } + +/// Returns the absolute value of the integer parameter. +/// Result is an unsigned integer. +pub fn absCast(x: var) -> @IntType(false, @typeOf(x).bit_count) { + const uint = @IntType(false, @typeOf(x).bit_count); + if (x >= 0) + return uint(x); + + return uint(-(x + 1)) + 1; +} + +test "math.absCast" { + assert(absCast(i32(-999)) == 999); + assert(@typeOf(absCast(i32(-999))) == u32); + + assert(absCast(i32(999)) == 999); + assert(@typeOf(absCast(i32(999))) == u32); + + assert(absCast(i32(@minValue(i32))) == -@minValue(i32)); + assert(@typeOf(absCast(i32(@minValue(i32)))) == u32); +} + +/// Returns the negation of the integer parameter. +/// Result is a signed integer. +error Overflow; +pub fn negateCast(x: var) -> %@IntType(true, @typeOf(x).bit_count) { + if (@typeOf(x).is_signed) + return negate(x); + + const int = @IntType(true, @typeOf(x).bit_count); + if (x > -@minValue(int)) + return error.Overflow; + + if (x == -@minValue(int)) + return @minValue(int); + + return -int(x); +} + +test "math.negateCast" { + assert(%%negateCast(u32(999)) == -999); + assert(@typeOf(%%negateCast(u32(999))) == i32); + + assert(%%negateCast(u32(-@minValue(i32))) == @minValue(i32)); + assert(@typeOf(%%negateCast(u32(-@minValue(i32)))) == i32); + + if (negateCast(u32(@maxValue(i32) + 10))) |_| unreachable else |err| assert(err == error.Overflow); +} diff --git a/std/rand.zig b/std/rand.zig index c55b366da9..63c65b7da9 100644 --- a/std/rand.zig +++ b/std/rand.zig @@ -1,6 +1,7 @@ const assert = @import("debug.zig").assert; const rand_test = @import("rand_test.zig"); const mem = @import("mem.zig"); +const math = @import("math.zig"); pub const MT19937_32 = MersenneTwister( u32, 624, 397, 31, @@ -63,18 +64,43 @@ pub const Rand = struct { /// Get a random unsigned integer with even distribution between `start` /// inclusive and `end` exclusive. - // TODO support signed integers and then rename to "range" - pub fn rangeUnsigned(r: &Rand, comptime T: type, start: T, end: T) -> T { - const range = end - start; - const leftover = @maxValue(T) % range; - const upper_bound = @maxValue(T) - leftover; - var rand_val_array: [@sizeOf(T)]u8 = undefined; + pub fn range(r: &Rand, comptime T: type, start: T, end: T) -> T { + assert(start <= end); + if (T.is_signed) { + const uint = @IntType(false, T.bit_count); + if (start >= 0 and end >= 0) { + return T(r.range(uint, uint(start), uint(end))); + } else if (start < 0 and end < 0) { + // Can't overflow because the range is over signed ints + return %%math.negateCast(r.range(uint, math.absCast(end), math.absCast(start)) + 1); + } else if (start < 0 and end >= 0) { + const end_uint = uint(end); + const total_range = math.absCast(start) + end_uint; + const value = r.range(uint, 0, total_range); + const result = if (value < end_uint) { + T(value) + } else if (value == end_uint) { + start + } else { + // Can't overflow because the range is over signed ints + %%math.negateCast(value - end_uint) + }; + return result; + } else { + unreachable; + } + } else { + const total_range = end - start; + const leftover = @maxValue(T) % total_range; + const upper_bound = @maxValue(T) - leftover; + var rand_val_array: [@sizeOf(T)]u8 = undefined; - while (true) { - r.fillBytes(rand_val_array[0..]); - const rand_val = mem.readInt(rand_val_array, T, false); - if (rand_val < upper_bound) { - return start + (rand_val % range); + while (true) { + r.fillBytes(rand_val_array[0..]); + const rand_val = mem.readInt(rand_val_array, T, false); + if (rand_val < upper_bound) { + return start + (rand_val % total_range); + } } } } @@ -94,7 +120,7 @@ pub const Rand = struct { } else { @compileError("unknown floating point type") }; - return T(r.rangeUnsigned(int_type, 0, precision)) / T(precision); + return T(r.range(int_type, 0, precision)) / T(precision); } }; @@ -175,16 +201,38 @@ test "rand float 32" { } } -test "testMT19937_64" { +test "rand.MT19937_64" { var rng = MT19937_64.init(rand_test.mt64_seed); for (rand_test.mt64_data) |value| { assert(value == rng.get()); } } -test "testMT19937_32" { +test "rand.MT19937_32" { var rng = MT19937_32.init(rand_test.mt32_seed); for (rand_test.mt32_data) |value| { assert(value == rng.get()); } } + +test "rand.Rand.range" { + var r = Rand.init(42); + testRange(&r, -4, 3); + testRange(&r, -4, -1); + testRange(&r, 10, 14); +} + +fn testRange(r: &Rand, start: i32, end: i32) { + const count = usize(end - start); + var values_buffer = []bool{false} ** 20; + const values = values_buffer[0..count]; + var i: usize = 0; + while (i < count) { + const value = r.range(i32, start, end); + const index = usize(value - start); + if (!values[index]) { + i += 1; + values[index] = true; + } + } +}