diff --git a/src/stage1/codegen.cpp b/src/stage1/codegen.cpp index b8ac867082..825d05a31c 100644 --- a/src/stage1/codegen.cpp +++ b/src/stage1/codegen.cpp @@ -1433,6 +1433,9 @@ static void add_sentinel_check(CodeGen *g, LLVMValueRef sentinel_elem_ptr, ZigVa static LLVMValueRef gen_assert_zero(CodeGen *g, LLVMValueRef expr_val, ZigType *int_type) { LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, int_type)); LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, zero, ""); + if (int_type->id == ZigTypeIdVector) { + ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit); + } LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenFail"); LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); @@ -1450,29 +1453,37 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z assert(actual_type->id == wanted_type->id); assert(expr_val != nullptr); + ZigType *scalar_actual_type = (actual_type->id == ZigTypeIdVector) ? + actual_type->data.vector.elem_type : actual_type; + ZigType *scalar_wanted_type = (wanted_type->id == ZigTypeIdVector) ? + wanted_type->data.vector.elem_type : wanted_type; + uint64_t actual_bits; uint64_t wanted_bits; - if (actual_type->id == ZigTypeIdFloat) { - actual_bits = actual_type->data.floating.bit_count; - wanted_bits = wanted_type->data.floating.bit_count; - } else if (actual_type->id == ZigTypeIdInt) { - actual_bits = actual_type->data.integral.bit_count; - wanted_bits = wanted_type->data.integral.bit_count; + if (scalar_actual_type->id == ZigTypeIdFloat) { + actual_bits = scalar_actual_type->data.floating.bit_count; + wanted_bits = scalar_wanted_type->data.floating.bit_count; + } else if (scalar_actual_type->id == ZigTypeIdInt) { + actual_bits = scalar_actual_type->data.integral.bit_count; + wanted_bits = scalar_wanted_type->data.integral.bit_count; } else { zig_unreachable(); } - if (actual_type->id == ZigTypeIdInt && want_runtime_safety && ( + if (scalar_actual_type->id == ZigTypeIdInt && want_runtime_safety && ( // negative to unsigned - (!wanted_type->data.integral.is_signed && actual_type->data.integral.is_signed) || + (!scalar_wanted_type->data.integral.is_signed && scalar_actual_type->data.integral.is_signed) || // unsigned would become negative - (wanted_type->data.integral.is_signed && !actual_type->data.integral.is_signed && actual_bits == wanted_bits))) + (scalar_wanted_type->data.integral.is_signed && !scalar_actual_type->data.integral.is_signed && actual_bits == wanted_bits))) { LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, actual_type)); LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntSGE, expr_val, zero, ""); LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "SignCastOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "SignCastFail"); + if (actual_type->id == ZigTypeIdVector) { + ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit); + } LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); LLVMPositionBuilderAtEnd(g->builder, fail_block); @@ -1484,10 +1495,10 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z if (actual_bits == wanted_bits) { return expr_val; } else if (actual_bits < wanted_bits) { - if (actual_type->id == ZigTypeIdFloat) { + if (scalar_actual_type->id == ZigTypeIdFloat) { return LLVMBuildFPExt(g->builder, expr_val, get_llvm_type(g, wanted_type), ""); - } else if (actual_type->id == ZigTypeIdInt) { - if (actual_type->data.integral.is_signed) { + } else if (scalar_actual_type->id == ZigTypeIdInt) { + if (scalar_actual_type->data.integral.is_signed) { return LLVMBuildSExt(g->builder, expr_val, get_llvm_type(g, wanted_type), ""); } else { return LLVMBuildZExt(g->builder, expr_val, get_llvm_type(g, wanted_type), ""); @@ -1496,9 +1507,9 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z zig_unreachable(); } } else if (actual_bits > wanted_bits) { - if (actual_type->id == ZigTypeIdFloat) { + if (scalar_actual_type->id == ZigTypeIdFloat) { return LLVMBuildFPTrunc(g->builder, expr_val, get_llvm_type(g, wanted_type), ""); - } else if (actual_type->id == ZigTypeIdInt) { + } else if (scalar_actual_type->id == ZigTypeIdInt) { if (wanted_bits == 0) { if (!want_runtime_safety) return nullptr; @@ -1510,12 +1521,15 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z return trunc_val; } LLVMValueRef orig_val; - if (wanted_type->data.integral.is_signed) { + if (scalar_wanted_type->data.integral.is_signed) { orig_val = LLVMBuildSExt(g->builder, trunc_val, get_llvm_type(g, actual_type), ""); } else { orig_val = LLVMBuildZExt(g->builder, trunc_val, get_llvm_type(g, actual_type), ""); } LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, orig_val, ""); + if (actual_type->id == ZigTypeIdVector) { + ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit); + } LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenOk"); LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenFail"); LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block); diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp index ea9bf6ee8b..d638406a7b 100644 --- a/src/stage1/ir.cpp +++ b/src/stage1/ir.cpp @@ -86,6 +86,8 @@ enum ConstCastResultId { ConstCastResultIdCV, ConstCastResultIdPtrSentinel, ConstCastResultIdIntShorten, + ConstCastResultIdVectorLength, + ConstCastResultIdVectorChild, }; struct ConstCastOnly; @@ -914,6 +916,7 @@ static bool types_have_same_zig_comptime_repr(CodeGen *codegen, ZigType *expecte if (is_opt_err_set(expected) && is_opt_err_set(actual)) return true; + // XXX: Vectors and arrays are interchangeable at comptime if (expected->id != actual->id) return false; @@ -947,9 +950,11 @@ static bool types_have_same_zig_comptime_repr(CodeGen *codegen, ZigType *expecte case ZigTypeIdErrorUnion: case ZigTypeIdEnum: case ZigTypeIdUnion: - case ZigTypeIdVector: case ZigTypeIdFnFrame: return false; + case ZigTypeIdVector: + return expected->data.vector.len == actual->data.vector.len && + types_have_same_zig_comptime_repr(codegen, expected->data.vector.elem_type, actual->data.vector.elem_type); case ZigTypeIdArray: return expected->data.array.len == actual->data.array.len && expected->data.array.child_type == actual->data.array.child_type && @@ -12190,6 +12195,24 @@ static ConstCastOnly types_match_const_cast_only(IrAnalyze *ira, ZigType *wanted return result; } + if (wanted_type->id == ZigTypeIdVector && actual_type->id == ZigTypeIdVector) { + if (actual_type->data.vector.len != wanted_type->data.vector.len) { + result.id = ConstCastResultIdVectorLength; + return result; + } + + ConstCastOnly child = types_match_const_cast_only(ira, wanted_type->data.vector.elem_type, + actual_type->data.vector.elem_type, source_node, false); + if (child.id == ConstCastResultIdInvalid) + return child; + if (child.id != ConstCastResultIdOk) { + result.id = ConstCastResultIdVectorChild; + return result; + } + + return result; + } + result.id = ConstCastResultIdType; result.data.type_mismatch = heap::c_allocator.allocate_nonzero(1); result.data.type_mismatch->wanted_type = wanted_type; @@ -14306,37 +14329,62 @@ static IrInstGen *ir_analyze_enum_to_union(IrAnalyze *ira, IrInst* source_instr, return ira->codegen->invalid_inst_gen; } +static bool value_numeric_fits_in_type(ZigValue *value, ZigType *type_entry); + static IrInstGen *ir_analyze_widen_or_shorten(IrAnalyze *ira, IrInst* source_instr, IrInstGen *target, ZigType *wanted_type) { - assert(wanted_type->id == ZigTypeIdInt || wanted_type->id == ZigTypeIdFloat); + ZigType *wanted_scalar_type = (target->value->type->id == ZigTypeIdVector) ? + wanted_type->data.vector.elem_type : wanted_type; + + assert(wanted_scalar_type->id == ZigTypeIdInt || wanted_scalar_type->id == ZigTypeIdFloat); if (instr_is_comptime(target)) { ZigValue *val = ir_resolve_const(ira, target, UndefBad); if (!val) return ira->codegen->invalid_inst_gen; - if (wanted_type->id == ZigTypeIdInt) { - if (bigint_cmp_zero(&val->data.x_bigint) == CmpLT && !wanted_type->data.integral.is_signed) { + + if (wanted_scalar_type->id == ZigTypeIdInt) { + if (!wanted_scalar_type->data.integral.is_signed && value_cmp_numeric_val_any(val, CmpLT, nullptr)) { ir_add_error(ira, source_instr, buf_sprintf("attempt to cast negative value to unsigned integer")); return ira->codegen->invalid_inst_gen; } - if (!bigint_fits_in_bits(&val->data.x_bigint, wanted_type->data.integral.bit_count, - wanted_type->data.integral.is_signed)) - { + if (!value_numeric_fits_in_type(val, wanted_scalar_type)) { ir_add_error(ira, source_instr, buf_sprintf("cast from '%s' to '%s' truncates bits", - buf_ptr(&target->value->type->name), buf_ptr(&wanted_type->name))); + buf_ptr(&target->value->type->name), buf_ptr(&wanted_scalar_type->name))); return ira->codegen->invalid_inst_gen; } } + IrInstGen *result = ir_const(ira, source_instr, wanted_type); result->value->type = wanted_type; - if (wanted_type->id == ZigTypeIdInt) { - bigint_init_bigint(&result->value->data.x_bigint, &val->data.x_bigint); + + if (wanted_type->id == ZigTypeIdVector) { + result->value->data.x_array.data.s_none.elements = ira->codegen->pass1_arena->allocate(wanted_type->data.vector.len); + + for (size_t i = 0; i < wanted_type->data.vector.len; i++) { + ZigValue *scalar_dest_value = &result->value->data.x_array.data.s_none.elements[i]; + ZigValue *scalar_src_value = &val->data.x_array.data.s_none.elements[i]; + + scalar_dest_value->type = wanted_scalar_type; + scalar_dest_value->special = ConstValSpecialStatic; + + if (wanted_scalar_type->id == ZigTypeIdInt) { + bigint_init_bigint(&scalar_dest_value->data.x_bigint, &scalar_src_value->data.x_bigint); + } else { + float_init_float(scalar_dest_value, scalar_src_value); + } + } } else { - float_init_float(result->value, val); + if (wanted_type->id == ZigTypeIdInt) { + bigint_init_bigint(&result->value->data.x_bigint, &val->data.x_bigint); + } else { + float_init_float(result->value, val); + } } + return result; } @@ -14779,6 +14827,8 @@ static void report_recursive_error(IrAnalyze *ira, AstNode *source_node, ConstCa actual_signed, actual_type->data.integral.bit_count)); break; } + case ConstCastResultIdVectorLength: // TODO + case ConstCastResultIdVectorChild: // TODO case ConstCastResultIdFnAlign: // TODO case ConstCastResultIdFnVarArgs: // TODO case ConstCastResultIdFnReturnType: // TODO @@ -15462,12 +15512,35 @@ static IrInstGen *ir_analyze_cast(IrAnalyze *ira, IrInst *source_instr, } // @Vector(N,T1) to @Vector(N,T2) - if (actual_type->id == ZigTypeIdVector && wanted_type->id == ZigTypeIdVector) { - if (actual_type->data.vector.len == wanted_type->data.vector.len && - types_match_const_cast_only(ira, wanted_type->data.vector.elem_type, - actual_type->data.vector.elem_type, source_node, false).id == ConstCastResultIdOk) + if (actual_type->id == ZigTypeIdVector && wanted_type->id == ZigTypeIdVector && + actual_type->data.vector.len == wanted_type->data.vector.len) + { + ZigType *scalar_actual_type = actual_type->data.vector.elem_type; + ZigType *scalar_wanted_type = wanted_type->data.vector.elem_type; + + // widening conversion + if (scalar_wanted_type->id == ZigTypeIdInt && + scalar_actual_type->id == ZigTypeIdInt && + scalar_wanted_type->data.integral.is_signed == scalar_actual_type->data.integral.is_signed && + scalar_wanted_type->data.integral.bit_count >= scalar_actual_type->data.integral.bit_count) { - return ir_analyze_bit_cast(ira, source_instr, value, wanted_type); + return ir_analyze_widen_or_shorten(ira, source_instr, value, wanted_type); + } + + // small enough unsigned ints can get casted to large enough signed ints + if (scalar_wanted_type->id == ZigTypeIdInt && scalar_wanted_type->data.integral.is_signed && + scalar_actual_type->id == ZigTypeIdInt && !scalar_actual_type->data.integral.is_signed && + scalar_wanted_type->data.integral.bit_count > scalar_actual_type->data.integral.bit_count) + { + return ir_analyze_widen_or_shorten(ira, source_instr, value, wanted_type); + } + + // float widening conversion + if (scalar_wanted_type->id == ZigTypeIdFloat && + scalar_actual_type->id == ZigTypeIdFloat && + scalar_wanted_type->data.floating.bit_count >= scalar_actual_type->data.floating.bit_count) + { + return ir_analyze_widen_or_shorten(ira, source_instr, value, wanted_type); } } @@ -17728,6 +17801,33 @@ static bool is_pointer_arithmetic_allowed(ZigType *lhs_type, IrBinOp op) { zig_unreachable(); } +// Returns true if integer `value` can be converted to `type_entry` without +// losing data. +// If `value` is a vector the function returns true if this is valid for every +// element. +static bool value_numeric_fits_in_type(ZigValue *value, ZigType *type_entry) { + assert(value->special == ConstValSpecialStatic); + assert(type_entry->id == ZigTypeIdInt); + + switch (value->type->id) { + case ZigTypeIdComptimeInt: + case ZigTypeIdInt: { + return bigint_fits_in_bits(&value->data.x_bigint, type_entry->data.integral.bit_count, + type_entry->data.integral.is_signed); + } + case ZigTypeIdVector: { + for (size_t i = 0; i < value->type->data.vector.len; i++) { + ZigValue *scalar_value = &value->data.x_array.data.s_none.elements[i]; + const bool result = bigint_fits_in_bits(&scalar_value->data.x_bigint, + type_entry->data.integral.bit_count, type_entry->data.integral.is_signed); + if (!result) return false; + } + return true; + } + default: zig_unreachable(); + } +} + static bool value_cmp_numeric_val(ZigValue *left, Cmp predicate, ZigValue *right, bool any) { assert(left->special == ConstValSpecialStatic); assert(right == nullptr || right->special == ConstValSpecialStatic); @@ -27154,8 +27254,12 @@ static IrInstGen *ir_analyze_instruction_int_cast(IrAnalyze *ira, IrInstSrcIntCa if (type_is_invalid(dest_type)) return ira->codegen->invalid_inst_gen; - if (dest_type->id != ZigTypeIdInt && dest_type->id != ZigTypeIdComptimeInt) { - ir_add_error(ira, &instruction->dest_type->base, buf_sprintf("expected integer type, found '%s'", buf_ptr(&dest_type->name))); + ZigType *scalar_dest_type = (dest_type->id == ZigTypeIdVector) ? + dest_type->data.vector.elem_type : dest_type; + + if (scalar_dest_type->id != ZigTypeIdInt && scalar_dest_type->id != ZigTypeIdComptimeInt) { + ir_add_error(ira, &instruction->dest_type->base, + buf_sprintf("expected integer type, found '%s'", buf_ptr(&scalar_dest_type->name))); return ira->codegen->invalid_inst_gen; } @@ -27163,13 +27267,16 @@ static IrInstGen *ir_analyze_instruction_int_cast(IrAnalyze *ira, IrInstSrcIntCa if (type_is_invalid(target->value->type)) return ira->codegen->invalid_inst_gen; - if (target->value->type->id != ZigTypeIdInt && target->value->type->id != ZigTypeIdComptimeInt) { + ZigType *scalar_target_type = (target->value->type->id == ZigTypeIdVector) ? + target->value->type->data.vector.elem_type : target->value->type; + + if (scalar_target_type->id != ZigTypeIdInt && scalar_target_type->id != ZigTypeIdComptimeInt) { ir_add_error(ira, &instruction->target->base, buf_sprintf("expected integer type, found '%s'", - buf_ptr(&target->value->type->name))); + buf_ptr(&scalar_target_type->name))); return ira->codegen->invalid_inst_gen; } - if (instr_is_comptime(target) || dest_type->id == ZigTypeIdComptimeInt) { + if (scalar_dest_type->id == ZigTypeIdComptimeInt) { ZigValue *val = ir_resolve_const(ira, target, UndefBad); if (val == nullptr) return ira->codegen->invalid_inst_gen; @@ -27222,6 +27329,7 @@ static IrInstGen *ir_analyze_instruction_float_cast(IrAnalyze *ira, IrInstSrcFlo if (val == nullptr) return ira->codegen->invalid_inst_gen; + // XXX: This will trigger an assertion failure if dest_type is comptime_float return ir_analyze_widen_or_shorten(ira, &instruction->target->base, target, dest_type); } diff --git a/test/compile_errors.zig b/test/compile_errors.zig index 062a357a93..1d3a2a6ea1 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -2944,7 +2944,6 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { "tmp.zig:4:18: error: expected type 'fn(i32) void', found 'fn(bool) void", "tmp.zig:4:18: note: parameter 0: 'bool' cannot cast into 'i32'", }); - cases.add("cast negative value to unsigned integer", \\comptime { \\ const value: i32 = -1; @@ -2955,7 +2954,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { \\ const unsigned: u32 = value; \\} , &[_][]const u8{ - "tmp.zig:3:36: error: cannot cast negative value -1 to unsigned integer type 'u32'", + "tmp.zig:3:22: error: attempt to cast negative value to unsigned integer", "tmp.zig:7:27: error: cannot cast negative value -1 to unsigned integer type 'u32'", }); @@ -2977,7 +2976,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { \\ var unsigned: u64 = signed; \\} , &[_][]const u8{ - "tmp.zig:3:31: error: integer value 300 cannot be coerced to type 'u8'", + "tmp.zig:3:18: error: cast from 'u16' to 'u8' truncates bits", "tmp.zig:7:22: error: integer value 300 cannot be coerced to type 'u8'", "tmp.zig:11:20: error: expected type 'u8', found 'u16'", "tmp.zig:11:20: note: unsigned 8-bit int cannot represent all possible unsigned 16-bit values", diff --git a/test/runtime_safety.zig b/test/runtime_safety.zig index 8d962b088e..a0b4a555a5 100644 --- a/test/runtime_safety.zig +++ b/test/runtime_safety.zig @@ -70,6 +70,36 @@ pub fn addCases(cases: *tests.CompareOutputContext) void { ); } + cases.addRuntimeSafety("truncating vector cast", + \\const std = @import("std"); + \\const V = @import("std").meta.Vector; + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ if (std.mem.eql(u8, message, "integer cast truncated bits")) { + \\ std.process.exit(126); // good + \\ } + \\ std.process.exit(0); // test failed + \\} + \\pub fn main() void { + \\ var x = @splat(4, @as(u32, 0xdeadbeef)); + \\ var y = @intCast(V(4, u16), x); + \\} + ); + + cases.addRuntimeSafety("unsigned-signed vector cast", + \\const std = @import("std"); + \\const V = @import("std").meta.Vector; + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ if (std.mem.eql(u8, message, "attempt to cast negative value to unsigned integer")) { + \\ std.process.exit(126); // good + \\ } + \\ std.process.exit(0); // test failed + \\} + \\pub fn main() void { + \\ var x = @splat(4, @as(u32, 0x80000000)); + \\ var y = @intCast(V(4, i32), x); + \\} + ); + cases.addRuntimeSafety("shift left by huge amount", \\const std = @import("std"); \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { diff --git a/test/stage1/behavior/cast.zig b/test/stage1/behavior/cast.zig index 0eb00512b9..d39b20a7e1 100644 --- a/test/stage1/behavior/cast.zig +++ b/test/stage1/behavior/cast.zig @@ -2,6 +2,7 @@ const std = @import("std"); const expect = std.testing.expect; const mem = std.mem; const maxInt = std.math.maxInt; +const Vector = std.meta.Vector; test "int to ptr cast" { const x = @as(usize, 13); @@ -364,6 +365,43 @@ test "@floatCast comptime_int and comptime_float" { } } +test "vector casts" { + const S = struct { + fn doTheTest() void { + // Upcast (implicit, equivalent to @intCast) + var up0: Vector(2, u8) = [_]u8{ 0x55, 0xaa }; + var up1 = @as(Vector(2, u16), up0); + var up2 = @as(Vector(2, u32), up0); + var up3 = @as(Vector(2, u64), up0); + // Downcast (safety-checked) + var down0 = up3; + var down1 = @intCast(Vector(2, u32), down0); + var down2 = @intCast(Vector(2, u16), down0); + var down3 = @intCast(Vector(2, u8), down0); + + expect(mem.eql(u16, &@as([2]u16, up1), &[2]u16{ 0x55, 0xaa })); + expect(mem.eql(u32, &@as([2]u32, up2), &[2]u32{ 0x55, 0xaa })); + expect(mem.eql(u64, &@as([2]u64, up3), &[2]u64{ 0x55, 0xaa })); + + expect(mem.eql(u32, &@as([2]u32, down1), &[2]u32{ 0x55, 0xaa })); + expect(mem.eql(u16, &@as([2]u16, down2), &[2]u16{ 0x55, 0xaa })); + expect(mem.eql(u8, &@as([2]u8, down3), &[2]u8{ 0x55, 0xaa })); + } + + fn doTheTestFloat() void { + var vec = @splat(2, @as(f32, 1234.0)); + var wider: Vector(2, f64) = vec; + expect(wider[0] == 1234.0); + expect(wider[1] == 1234.0); + } + }; + + S.doTheTest(); + comptime S.doTheTest(); + S.doTheTestFloat(); + comptime S.doTheTestFloat(); +} + test "comptime_int @intToFloat" { { const result = @intToFloat(f16, 1234);