change std.rand.Rand.rangeUnsigned to std.rand.Rand.range

and make it support signed integers
This commit is contained in:
Andrew Kelley 2017-05-31 18:23:56 -04:00
parent 1ae2002b41
commit 1e301b03a9
3 changed files with 115 additions and 15 deletions

View File

@ -12,7 +12,7 @@ pub fn main() -> %void {
const seed = std.mem.readInt(seed_bytes, usize, true);
var rand = Rand.init(seed);
const answer = rand.rangeUnsigned(u8, 0, 100) + 1;
const answer = rand.range(u8, 0, 100) + 1;
while (true) {
%%io.stdout.printf("\nGuess a number between 1 and 100: ");

View File

@ -41,6 +41,10 @@ pub fn sub(comptime T: type, a: T, b: T) -> %T {
if (@subWithOverflow(T, a, b, &answer)) error.Overflow else answer
}
pub fn negate(x: var) -> %@typeOf(x) {
return sub(@typeOf(x), 0, x);
}
error Overflow;
pub fn shl(comptime T: type, a: T, b: T) -> %T {
var answer: T = undefined;
@ -341,3 +345,51 @@ test "math.floor" {
assert(floor(f64(999.0)) == 999.0);
assert(floor(f64(-999.0)) == -999.0);
}
/// Returns the absolute value of the integer parameter.
/// Result is an unsigned integer.
pub fn absCast(x: var) -> @IntType(false, @typeOf(x).bit_count) {
const uint = @IntType(false, @typeOf(x).bit_count);
if (x >= 0)
return uint(x);
return uint(-(x + 1)) + 1;
}
test "math.absCast" {
assert(absCast(i32(-999)) == 999);
assert(@typeOf(absCast(i32(-999))) == u32);
assert(absCast(i32(999)) == 999);
assert(@typeOf(absCast(i32(999))) == u32);
assert(absCast(i32(@minValue(i32))) == -@minValue(i32));
assert(@typeOf(absCast(i32(@minValue(i32)))) == u32);
}
/// Returns the negation of the integer parameter.
/// Result is a signed integer.
error Overflow;
pub fn negateCast(x: var) -> %@IntType(true, @typeOf(x).bit_count) {
if (@typeOf(x).is_signed)
return negate(x);
const int = @IntType(true, @typeOf(x).bit_count);
if (x > -@minValue(int))
return error.Overflow;
if (x == -@minValue(int))
return @minValue(int);
return -int(x);
}
test "math.negateCast" {
assert(%%negateCast(u32(999)) == -999);
assert(@typeOf(%%negateCast(u32(999))) == i32);
assert(%%negateCast(u32(-@minValue(i32))) == @minValue(i32));
assert(@typeOf(%%negateCast(u32(-@minValue(i32)))) == i32);
if (negateCast(u32(@maxValue(i32) + 10))) |_| unreachable else |err| assert(err == error.Overflow);
}

View File

@ -1,6 +1,7 @@
const assert = @import("debug.zig").assert;
const rand_test = @import("rand_test.zig");
const mem = @import("mem.zig");
const math = @import("math.zig");
pub const MT19937_32 = MersenneTwister(
u32, 624, 397, 31,
@ -63,18 +64,43 @@ pub const Rand = struct {
/// Get a random unsigned integer with even distribution between `start`
/// inclusive and `end` exclusive.
// TODO support signed integers and then rename to "range"
pub fn rangeUnsigned(r: &Rand, comptime T: type, start: T, end: T) -> T {
const range = end - start;
const leftover = @maxValue(T) % range;
const upper_bound = @maxValue(T) - leftover;
var rand_val_array: [@sizeOf(T)]u8 = undefined;
pub fn range(r: &Rand, comptime T: type, start: T, end: T) -> T {
assert(start <= end);
if (T.is_signed) {
const uint = @IntType(false, T.bit_count);
if (start >= 0 and end >= 0) {
return T(r.range(uint, uint(start), uint(end)));
} else if (start < 0 and end < 0) {
// Can't overflow because the range is over signed ints
return %%math.negateCast(r.range(uint, math.absCast(end), math.absCast(start)) + 1);
} else if (start < 0 and end >= 0) {
const end_uint = uint(end);
const total_range = math.absCast(start) + end_uint;
const value = r.range(uint, 0, total_range);
const result = if (value < end_uint) {
T(value)
} else if (value == end_uint) {
start
} else {
// Can't overflow because the range is over signed ints
%%math.negateCast(value - end_uint)
};
return result;
} else {
unreachable;
}
} else {
const total_range = end - start;
const leftover = @maxValue(T) % total_range;
const upper_bound = @maxValue(T) - leftover;
var rand_val_array: [@sizeOf(T)]u8 = undefined;
while (true) {
r.fillBytes(rand_val_array[0..]);
const rand_val = mem.readInt(rand_val_array, T, false);
if (rand_val < upper_bound) {
return start + (rand_val % range);
while (true) {
r.fillBytes(rand_val_array[0..]);
const rand_val = mem.readInt(rand_val_array, T, false);
if (rand_val < upper_bound) {
return start + (rand_val % total_range);
}
}
}
}
@ -94,7 +120,7 @@ pub const Rand = struct {
} else {
@compileError("unknown floating point type")
};
return T(r.rangeUnsigned(int_type, 0, precision)) / T(precision);
return T(r.range(int_type, 0, precision)) / T(precision);
}
};
@ -175,16 +201,38 @@ test "rand float 32" {
}
}
test "testMT19937_64" {
test "rand.MT19937_64" {
var rng = MT19937_64.init(rand_test.mt64_seed);
for (rand_test.mt64_data) |value| {
assert(value == rng.get());
}
}
test "testMT19937_32" {
test "rand.MT19937_32" {
var rng = MT19937_32.init(rand_test.mt32_seed);
for (rand_test.mt32_data) |value| {
assert(value == rng.get());
}
}
test "rand.Rand.range" {
var r = Rand.init(42);
testRange(&r, -4, 3);
testRange(&r, -4, -1);
testRange(&r, 10, 14);
}
fn testRange(r: &Rand, start: i32, end: i32) {
const count = usize(end - start);
var values_buffer = []bool{false} ** 20;
const values = values_buffer[0..count];
var i: usize = 0;
while (i < count) {
const value = r.range(i32, start, end);
const index = usize(value - start);
if (!values[index]) {
i += 1;
values[index] = true;
}
}
}