stage1: Implement @intCast between vectors

Explicit and implicit integer casts on vector types are now supported
and follow the same rules as their scalar counterparts.

Implicit float casts are accidentally supported, `@floatCast` is still
not vector-aware.
This commit is contained in:
LemonBoy 2020-10-17 09:46:11 +02:00
parent 245d98d32d
commit 2f465761bb
5 changed files with 228 additions and 39 deletions

View File

@ -1433,6 +1433,9 @@ static void add_sentinel_check(CodeGen *g, LLVMValueRef sentinel_elem_ptr, ZigVa
static LLVMValueRef gen_assert_zero(CodeGen *g, LLVMValueRef expr_val, ZigType *int_type) {
LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, int_type));
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, zero, "");
if (int_type->id == ZigTypeIdVector) {
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenFail");
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
@ -1450,29 +1453,37 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
assert(actual_type->id == wanted_type->id);
assert(expr_val != nullptr);
ZigType *scalar_actual_type = (actual_type->id == ZigTypeIdVector) ?
actual_type->data.vector.elem_type : actual_type;
ZigType *scalar_wanted_type = (wanted_type->id == ZigTypeIdVector) ?
wanted_type->data.vector.elem_type : wanted_type;
uint64_t actual_bits;
uint64_t wanted_bits;
if (actual_type->id == ZigTypeIdFloat) {
actual_bits = actual_type->data.floating.bit_count;
wanted_bits = wanted_type->data.floating.bit_count;
} else if (actual_type->id == ZigTypeIdInt) {
actual_bits = actual_type->data.integral.bit_count;
wanted_bits = wanted_type->data.integral.bit_count;
if (scalar_actual_type->id == ZigTypeIdFloat) {
actual_bits = scalar_actual_type->data.floating.bit_count;
wanted_bits = scalar_wanted_type->data.floating.bit_count;
} else if (scalar_actual_type->id == ZigTypeIdInt) {
actual_bits = scalar_actual_type->data.integral.bit_count;
wanted_bits = scalar_wanted_type->data.integral.bit_count;
} else {
zig_unreachable();
}
if (actual_type->id == ZigTypeIdInt && want_runtime_safety && (
if (scalar_actual_type->id == ZigTypeIdInt && want_runtime_safety && (
// negative to unsigned
(!wanted_type->data.integral.is_signed && actual_type->data.integral.is_signed) ||
(!scalar_wanted_type->data.integral.is_signed && scalar_actual_type->data.integral.is_signed) ||
// unsigned would become negative
(wanted_type->data.integral.is_signed && !actual_type->data.integral.is_signed && actual_bits == wanted_bits)))
(scalar_wanted_type->data.integral.is_signed && !scalar_actual_type->data.integral.is_signed && actual_bits == wanted_bits)))
{
LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, actual_type));
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntSGE, expr_val, zero, "");
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "SignCastOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "SignCastFail");
if (actual_type->id == ZigTypeIdVector) {
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
LLVMPositionBuilderAtEnd(g->builder, fail_block);
@ -1484,10 +1495,10 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
if (actual_bits == wanted_bits) {
return expr_val;
} else if (actual_bits < wanted_bits) {
if (actual_type->id == ZigTypeIdFloat) {
if (scalar_actual_type->id == ZigTypeIdFloat) {
return LLVMBuildFPExt(g->builder, expr_val, get_llvm_type(g, wanted_type), "");
} else if (actual_type->id == ZigTypeIdInt) {
if (actual_type->data.integral.is_signed) {
} else if (scalar_actual_type->id == ZigTypeIdInt) {
if (scalar_actual_type->data.integral.is_signed) {
return LLVMBuildSExt(g->builder, expr_val, get_llvm_type(g, wanted_type), "");
} else {
return LLVMBuildZExt(g->builder, expr_val, get_llvm_type(g, wanted_type), "");
@ -1496,9 +1507,9 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
zig_unreachable();
}
} else if (actual_bits > wanted_bits) {
if (actual_type->id == ZigTypeIdFloat) {
if (scalar_actual_type->id == ZigTypeIdFloat) {
return LLVMBuildFPTrunc(g->builder, expr_val, get_llvm_type(g, wanted_type), "");
} else if (actual_type->id == ZigTypeIdInt) {
} else if (scalar_actual_type->id == ZigTypeIdInt) {
if (wanted_bits == 0) {
if (!want_runtime_safety)
return nullptr;
@ -1510,12 +1521,15 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
return trunc_val;
}
LLVMValueRef orig_val;
if (wanted_type->data.integral.is_signed) {
if (scalar_wanted_type->data.integral.is_signed) {
orig_val = LLVMBuildSExt(g->builder, trunc_val, get_llvm_type(g, actual_type), "");
} else {
orig_val = LLVMBuildZExt(g->builder, trunc_val, get_llvm_type(g, actual_type), "");
}
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, orig_val, "");
if (actual_type->id == ZigTypeIdVector) {
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenFail");
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

View File

@ -86,6 +86,8 @@ enum ConstCastResultId {
ConstCastResultIdCV,
ConstCastResultIdPtrSentinel,
ConstCastResultIdIntShorten,
ConstCastResultIdVectorLength,
ConstCastResultIdVectorChild,
};
struct ConstCastOnly;
@ -914,6 +916,7 @@ static bool types_have_same_zig_comptime_repr(CodeGen *codegen, ZigType *expecte
if (is_opt_err_set(expected) && is_opt_err_set(actual))
return true;
// XXX: Vectors and arrays are interchangeable at comptime
if (expected->id != actual->id)
return false;
@ -947,9 +950,11 @@ static bool types_have_same_zig_comptime_repr(CodeGen *codegen, ZigType *expecte
case ZigTypeIdErrorUnion:
case ZigTypeIdEnum:
case ZigTypeIdUnion:
case ZigTypeIdVector:
case ZigTypeIdFnFrame:
return false;
case ZigTypeIdVector:
return expected->data.vector.len == actual->data.vector.len &&
types_have_same_zig_comptime_repr(codegen, expected->data.vector.elem_type, actual->data.vector.elem_type);
case ZigTypeIdArray:
return expected->data.array.len == actual->data.array.len &&
expected->data.array.child_type == actual->data.array.child_type &&
@ -12190,6 +12195,24 @@ static ConstCastOnly types_match_const_cast_only(IrAnalyze *ira, ZigType *wanted
return result;
}
if (wanted_type->id == ZigTypeIdVector && actual_type->id == ZigTypeIdVector) {
if (actual_type->data.vector.len != wanted_type->data.vector.len) {
result.id = ConstCastResultIdVectorLength;
return result;
}
ConstCastOnly child = types_match_const_cast_only(ira, wanted_type->data.vector.elem_type,
actual_type->data.vector.elem_type, source_node, false);
if (child.id == ConstCastResultIdInvalid)
return child;
if (child.id != ConstCastResultIdOk) {
result.id = ConstCastResultIdVectorChild;
return result;
}
return result;
}
result.id = ConstCastResultIdType;
result.data.type_mismatch = heap::c_allocator.allocate_nonzero<ConstCastTypeMismatch>(1);
result.data.type_mismatch->wanted_type = wanted_type;
@ -14306,37 +14329,62 @@ static IrInstGen *ir_analyze_enum_to_union(IrAnalyze *ira, IrInst* source_instr,
return ira->codegen->invalid_inst_gen;
}
static bool value_numeric_fits_in_type(ZigValue *value, ZigType *type_entry);
static IrInstGen *ir_analyze_widen_or_shorten(IrAnalyze *ira, IrInst* source_instr,
IrInstGen *target, ZigType *wanted_type)
{
assert(wanted_type->id == ZigTypeIdInt || wanted_type->id == ZigTypeIdFloat);
ZigType *wanted_scalar_type = (target->value->type->id == ZigTypeIdVector) ?
wanted_type->data.vector.elem_type : wanted_type;
assert(wanted_scalar_type->id == ZigTypeIdInt || wanted_scalar_type->id == ZigTypeIdFloat);
if (instr_is_comptime(target)) {
ZigValue *val = ir_resolve_const(ira, target, UndefBad);
if (!val)
return ira->codegen->invalid_inst_gen;
if (wanted_type->id == ZigTypeIdInt) {
if (bigint_cmp_zero(&val->data.x_bigint) == CmpLT && !wanted_type->data.integral.is_signed) {
if (wanted_scalar_type->id == ZigTypeIdInt) {
if (!wanted_scalar_type->data.integral.is_signed && value_cmp_numeric_val_any(val, CmpLT, nullptr)) {
ir_add_error(ira, source_instr,
buf_sprintf("attempt to cast negative value to unsigned integer"));
return ira->codegen->invalid_inst_gen;
}
if (!bigint_fits_in_bits(&val->data.x_bigint, wanted_type->data.integral.bit_count,
wanted_type->data.integral.is_signed))
{
if (!value_numeric_fits_in_type(val, wanted_scalar_type)) {
ir_add_error(ira, source_instr,
buf_sprintf("cast from '%s' to '%s' truncates bits",
buf_ptr(&target->value->type->name), buf_ptr(&wanted_type->name)));
buf_ptr(&target->value->type->name), buf_ptr(&wanted_scalar_type->name)));
return ira->codegen->invalid_inst_gen;
}
}
IrInstGen *result = ir_const(ira, source_instr, wanted_type);
result->value->type = wanted_type;
if (wanted_type->id == ZigTypeIdInt) {
bigint_init_bigint(&result->value->data.x_bigint, &val->data.x_bigint);
if (wanted_type->id == ZigTypeIdVector) {
result->value->data.x_array.data.s_none.elements = ira->codegen->pass1_arena->allocate<ZigValue>(wanted_type->data.vector.len);
for (size_t i = 0; i < wanted_type->data.vector.len; i++) {
ZigValue *scalar_dest_value = &result->value->data.x_array.data.s_none.elements[i];
ZigValue *scalar_src_value = &val->data.x_array.data.s_none.elements[i];
scalar_dest_value->type = wanted_scalar_type;
scalar_dest_value->special = ConstValSpecialStatic;
if (wanted_scalar_type->id == ZigTypeIdInt) {
bigint_init_bigint(&scalar_dest_value->data.x_bigint, &scalar_src_value->data.x_bigint);
} else {
float_init_float(scalar_dest_value, scalar_src_value);
}
}
} else {
float_init_float(result->value, val);
if (wanted_type->id == ZigTypeIdInt) {
bigint_init_bigint(&result->value->data.x_bigint, &val->data.x_bigint);
} else {
float_init_float(result->value, val);
}
}
return result;
}
@ -14779,6 +14827,8 @@ static void report_recursive_error(IrAnalyze *ira, AstNode *source_node, ConstCa
actual_signed, actual_type->data.integral.bit_count));
break;
}
case ConstCastResultIdVectorLength: // TODO
case ConstCastResultIdVectorChild: // TODO
case ConstCastResultIdFnAlign: // TODO
case ConstCastResultIdFnVarArgs: // TODO
case ConstCastResultIdFnReturnType: // TODO
@ -15462,12 +15512,35 @@ static IrInstGen *ir_analyze_cast(IrAnalyze *ira, IrInst *source_instr,
}
// @Vector(N,T1) to @Vector(N,T2)
if (actual_type->id == ZigTypeIdVector && wanted_type->id == ZigTypeIdVector) {
if (actual_type->data.vector.len == wanted_type->data.vector.len &&
types_match_const_cast_only(ira, wanted_type->data.vector.elem_type,
actual_type->data.vector.elem_type, source_node, false).id == ConstCastResultIdOk)
if (actual_type->id == ZigTypeIdVector && wanted_type->id == ZigTypeIdVector &&
actual_type->data.vector.len == wanted_type->data.vector.len)
{
ZigType *scalar_actual_type = actual_type->data.vector.elem_type;
ZigType *scalar_wanted_type = wanted_type->data.vector.elem_type;
// widening conversion
if (scalar_wanted_type->id == ZigTypeIdInt &&
scalar_actual_type->id == ZigTypeIdInt &&
scalar_wanted_type->data.integral.is_signed == scalar_actual_type->data.integral.is_signed &&
scalar_wanted_type->data.integral.bit_count >= scalar_actual_type->data.integral.bit_count)
{
return ir_analyze_bit_cast(ira, source_instr, value, wanted_type);
return ir_analyze_widen_or_shorten(ira, source_instr, value, wanted_type);
}
// small enough unsigned ints can get casted to large enough signed ints
if (scalar_wanted_type->id == ZigTypeIdInt && scalar_wanted_type->data.integral.is_signed &&
scalar_actual_type->id == ZigTypeIdInt && !scalar_actual_type->data.integral.is_signed &&
scalar_wanted_type->data.integral.bit_count > scalar_actual_type->data.integral.bit_count)
{
return ir_analyze_widen_or_shorten(ira, source_instr, value, wanted_type);
}
// float widening conversion
if (scalar_wanted_type->id == ZigTypeIdFloat &&
scalar_actual_type->id == ZigTypeIdFloat &&
scalar_wanted_type->data.floating.bit_count >= scalar_actual_type->data.floating.bit_count)
{
return ir_analyze_widen_or_shorten(ira, source_instr, value, wanted_type);
}
}
@ -17728,6 +17801,33 @@ static bool is_pointer_arithmetic_allowed(ZigType *lhs_type, IrBinOp op) {
zig_unreachable();
}
// Returns true if integer `value` can be converted to `type_entry` without
// losing data.
// If `value` is a vector the function returns true if this is valid for every
// element.
static bool value_numeric_fits_in_type(ZigValue *value, ZigType *type_entry) {
assert(value->special == ConstValSpecialStatic);
assert(type_entry->id == ZigTypeIdInt);
switch (value->type->id) {
case ZigTypeIdComptimeInt:
case ZigTypeIdInt: {
return bigint_fits_in_bits(&value->data.x_bigint, type_entry->data.integral.bit_count,
type_entry->data.integral.is_signed);
}
case ZigTypeIdVector: {
for (size_t i = 0; i < value->type->data.vector.len; i++) {
ZigValue *scalar_value = &value->data.x_array.data.s_none.elements[i];
const bool result = bigint_fits_in_bits(&scalar_value->data.x_bigint,
type_entry->data.integral.bit_count, type_entry->data.integral.is_signed);
if (!result) return false;
}
return true;
}
default: zig_unreachable();
}
}
static bool value_cmp_numeric_val(ZigValue *left, Cmp predicate, ZigValue *right, bool any) {
assert(left->special == ConstValSpecialStatic);
assert(right == nullptr || right->special == ConstValSpecialStatic);
@ -27154,8 +27254,12 @@ static IrInstGen *ir_analyze_instruction_int_cast(IrAnalyze *ira, IrInstSrcIntCa
if (type_is_invalid(dest_type))
return ira->codegen->invalid_inst_gen;
if (dest_type->id != ZigTypeIdInt && dest_type->id != ZigTypeIdComptimeInt) {
ir_add_error(ira, &instruction->dest_type->base, buf_sprintf("expected integer type, found '%s'", buf_ptr(&dest_type->name)));
ZigType *scalar_dest_type = (dest_type->id == ZigTypeIdVector) ?
dest_type->data.vector.elem_type : dest_type;
if (scalar_dest_type->id != ZigTypeIdInt && scalar_dest_type->id != ZigTypeIdComptimeInt) {
ir_add_error(ira, &instruction->dest_type->base,
buf_sprintf("expected integer type, found '%s'", buf_ptr(&scalar_dest_type->name)));
return ira->codegen->invalid_inst_gen;
}
@ -27163,13 +27267,16 @@ static IrInstGen *ir_analyze_instruction_int_cast(IrAnalyze *ira, IrInstSrcIntCa
if (type_is_invalid(target->value->type))
return ira->codegen->invalid_inst_gen;
if (target->value->type->id != ZigTypeIdInt && target->value->type->id != ZigTypeIdComptimeInt) {
ZigType *scalar_target_type = (target->value->type->id == ZigTypeIdVector) ?
target->value->type->data.vector.elem_type : target->value->type;
if (scalar_target_type->id != ZigTypeIdInt && scalar_target_type->id != ZigTypeIdComptimeInt) {
ir_add_error(ira, &instruction->target->base, buf_sprintf("expected integer type, found '%s'",
buf_ptr(&target->value->type->name)));
buf_ptr(&scalar_target_type->name)));
return ira->codegen->invalid_inst_gen;
}
if (instr_is_comptime(target) || dest_type->id == ZigTypeIdComptimeInt) {
if (scalar_dest_type->id == ZigTypeIdComptimeInt) {
ZigValue *val = ir_resolve_const(ira, target, UndefBad);
if (val == nullptr)
return ira->codegen->invalid_inst_gen;
@ -27222,6 +27329,7 @@ static IrInstGen *ir_analyze_instruction_float_cast(IrAnalyze *ira, IrInstSrcFlo
if (val == nullptr)
return ira->codegen->invalid_inst_gen;
// XXX: This will trigger an assertion failure if dest_type is comptime_float
return ir_analyze_widen_or_shorten(ira, &instruction->target->base, target, dest_type);
}

View File

@ -2944,7 +2944,6 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
"tmp.zig:4:18: error: expected type 'fn(i32) void', found 'fn(bool) void",
"tmp.zig:4:18: note: parameter 0: 'bool' cannot cast into 'i32'",
});
cases.add("cast negative value to unsigned integer",
\\comptime {
\\ const value: i32 = -1;
@ -2955,7 +2954,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
\\ const unsigned: u32 = value;
\\}
, &[_][]const u8{
"tmp.zig:3:36: error: cannot cast negative value -1 to unsigned integer type 'u32'",
"tmp.zig:3:22: error: attempt to cast negative value to unsigned integer",
"tmp.zig:7:27: error: cannot cast negative value -1 to unsigned integer type 'u32'",
});
@ -2977,7 +2976,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
\\ var unsigned: u64 = signed;
\\}
, &[_][]const u8{
"tmp.zig:3:31: error: integer value 300 cannot be coerced to type 'u8'",
"tmp.zig:3:18: error: cast from 'u16' to 'u8' truncates bits",
"tmp.zig:7:22: error: integer value 300 cannot be coerced to type 'u8'",
"tmp.zig:11:20: error: expected type 'u8', found 'u16'",
"tmp.zig:11:20: note: unsigned 8-bit int cannot represent all possible unsigned 16-bit values",

View File

@ -70,6 +70,36 @@ pub fn addCases(cases: *tests.CompareOutputContext) void {
);
}
cases.addRuntimeSafety("truncating vector cast",
\\const std = @import("std");
\\const V = @import("std").meta.Vector;
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
\\ if (std.mem.eql(u8, message, "integer cast truncated bits")) {
\\ std.process.exit(126); // good
\\ }
\\ std.process.exit(0); // test failed
\\}
\\pub fn main() void {
\\ var x = @splat(4, @as(u32, 0xdeadbeef));
\\ var y = @intCast(V(4, u16), x);
\\}
);
cases.addRuntimeSafety("unsigned-signed vector cast",
\\const std = @import("std");
\\const V = @import("std").meta.Vector;
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
\\ if (std.mem.eql(u8, message, "attempt to cast negative value to unsigned integer")) {
\\ std.process.exit(126); // good
\\ }
\\ std.process.exit(0); // test failed
\\}
\\pub fn main() void {
\\ var x = @splat(4, @as(u32, 0x80000000));
\\ var y = @intCast(V(4, i32), x);
\\}
);
cases.addRuntimeSafety("shift left by huge amount",
\\const std = @import("std");
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {

View File

@ -2,6 +2,7 @@ const std = @import("std");
const expect = std.testing.expect;
const mem = std.mem;
const maxInt = std.math.maxInt;
const Vector = std.meta.Vector;
test "int to ptr cast" {
const x = @as(usize, 13);
@ -364,6 +365,43 @@ test "@floatCast comptime_int and comptime_float" {
}
}
test "vector casts" {
const S = struct {
fn doTheTest() void {
// Upcast (implicit, equivalent to @intCast)
var up0: Vector(2, u8) = [_]u8{ 0x55, 0xaa };
var up1 = @as(Vector(2, u16), up0);
var up2 = @as(Vector(2, u32), up0);
var up3 = @as(Vector(2, u64), up0);
// Downcast (safety-checked)
var down0 = up3;
var down1 = @intCast(Vector(2, u32), down0);
var down2 = @intCast(Vector(2, u16), down0);
var down3 = @intCast(Vector(2, u8), down0);
expect(mem.eql(u16, &@as([2]u16, up1), &[2]u16{ 0x55, 0xaa }));
expect(mem.eql(u32, &@as([2]u32, up2), &[2]u32{ 0x55, 0xaa }));
expect(mem.eql(u64, &@as([2]u64, up3), &[2]u64{ 0x55, 0xaa }));
expect(mem.eql(u32, &@as([2]u32, down1), &[2]u32{ 0x55, 0xaa }));
expect(mem.eql(u16, &@as([2]u16, down2), &[2]u16{ 0x55, 0xaa }));
expect(mem.eql(u8, &@as([2]u8, down3), &[2]u8{ 0x55, 0xaa }));
}
fn doTheTestFloat() void {
var vec = @splat(2, @as(f32, 1234.0));
var wider: Vector(2, f64) = vec;
expect(wider[0] == 1234.0);
expect(wider[1] == 1234.0);
}
};
S.doTheTest();
comptime S.doTheTest();
S.doTheTestFloat();
comptime S.doTheTestFloat();
}
test "comptime_int @intToFloat" {
{
const result = @intToFloat(f16, 1234);