From 20b1491e6bd87b10e822282c5867604d634973a1 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 4 May 2017 10:18:01 -0400 Subject: [PATCH] implement while for nullables and error unions See #357 --- src/ast_render.cpp | 19 +++- src/ir.cpp | 212 +++++++++++++++++++++++++++++++++------- src/parser.cpp | 2 +- std/linked_list.zig | 65 ++++-------- test/cases/while.zig | 80 ++++++++++++++- test/compile_errors.zig | 48 +++++++++ 6 files changed, 336 insertions(+), 90 deletions(-) diff --git a/src/ast_render.cpp b/src/ast_render.cpp index 902891b5a2..348d825332 100644 --- a/src/ast_render.cpp +++ b/src/ast_render.cpp @@ -725,12 +725,23 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) { const char *inline_str = node->data.while_expr.is_inline ? "inline " : ""; fprintf(ar->f, "%swhile (", inline_str); render_node_grouped(ar, node->data.while_expr.condition); - if (node->data.while_expr.continue_expr) { - fprintf(ar->f, "; "); - render_node_grouped(ar, node->data.while_expr.continue_expr); - } fprintf(ar->f, ") "); + if (node->data.while_expr.var_symbol) { + fprintf(ar->f, "|%s| ", buf_ptr(node->data.while_expr.var_symbol)); + } + if (node->data.while_expr.continue_expr) { + fprintf(ar->f, ": ("); + render_node_grouped(ar, node->data.while_expr.continue_expr); + fprintf(ar->f, ") "); + } render_node_grouped(ar, node->data.while_expr.body); + if (node->data.while_expr.else_node) { + fprintf(ar->f, " else "); + if (node->data.while_expr.err_symbol) { + fprintf(ar->f, "|%s| ", buf_ptr(node->data.while_expr.err_symbol)); + } + render_node_grouped(ar, node->data.while_expr.else_node); + } break; } case NodeTypeThisLiteral: diff --git a/src/ir.cpp b/src/ir.cpp index def7469a99..3b541b7731 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -4658,51 +4658,197 @@ static IrInstruction *ir_gen_while_expr(IrBuilder *irb, Scope *scope, AstNode *n assert(node->type == NodeTypeWhileExpr); AstNode *continue_expr_node = node->data.while_expr.continue_expr; + AstNode *else_node = node->data.while_expr.else_node; IrBasicBlock *cond_block = ir_build_basic_block(irb, scope, "WhileCond"); IrBasicBlock *body_block = ir_build_basic_block(irb, scope, "WhileBody"); IrBasicBlock *continue_block = continue_expr_node ? ir_build_basic_block(irb, scope, "WhileContinue") : cond_block; IrBasicBlock *end_block = ir_build_basic_block(irb, scope, "WhileEnd"); + IrBasicBlock *else_block = else_node ? + ir_build_basic_block(irb, scope, "WhileElse") : end_block; IrInstruction *is_comptime = ir_build_const_bool(irb, scope, node, ir_should_inline(irb->exec, scope) || node->data.while_expr.is_inline); ir_build_br(irb, scope, node, cond_block, is_comptime); - if (continue_expr_node) { - ir_set_cursor_at_end(irb, continue_block); - IrInstruction *expr_result = ir_gen_node(irb, continue_expr_node, scope); - if (expr_result == irb->codegen->invalid_instruction) - return expr_result; - if (!instr_is_unreachable(expr_result)) - ir_mark_gen(ir_build_br(irb, scope, node, cond_block, is_comptime)); + Buf *var_symbol = node->data.while_expr.var_symbol; + Buf *err_symbol = node->data.while_expr.err_symbol; + if (err_symbol != nullptr) { + ir_set_cursor_at_end(irb, cond_block); + + Scope *payload_scope; + AstNode *symbol_node = node; // TODO make more accurate + VariableTableEntry *payload_var; + if (var_symbol) { + // TODO make it an error to write to payload variable + payload_var = ir_create_var(irb, symbol_node, scope, var_symbol, + true, false, false, is_comptime); + payload_scope = payload_var->child_scope; + } else { + payload_scope = scope; + } + IrInstruction *err_val_ptr = ir_gen_node_extra(irb, node->data.while_expr.condition, scope, LVAL_PTR); + if (err_val_ptr == irb->codegen->invalid_instruction) + return err_val_ptr; + IrInstruction *err_val = ir_build_load_ptr(irb, scope, node->data.while_expr.condition, err_val_ptr); + IrInstruction *is_err = ir_build_test_err(irb, scope, node->data.while_expr.condition, err_val); + if (!instr_is_unreachable(is_err)) { + ir_mark_gen(ir_build_cond_br(irb, scope, node->data.while_expr.condition, is_err, + else_block, body_block, is_comptime)); + } + + ir_set_cursor_at_end(irb, body_block); + if (var_symbol) { + IrInstruction *var_ptr_value = ir_build_unwrap_err_payload(irb, payload_scope, symbol_node, + err_val_ptr, false); + IrInstruction *var_value = node->data.while_expr.var_is_ptr ? + var_ptr_value : ir_build_load_ptr(irb, payload_scope, symbol_node, var_ptr_value); + ir_build_var_decl(irb, payload_scope, symbol_node, payload_var, nullptr, var_value); + } + LoopStackItem *loop_stack_item = irb->loop_stack.add_one(); + loop_stack_item->break_block = end_block; + loop_stack_item->continue_block = continue_block; + loop_stack_item->is_comptime = is_comptime; + IrInstruction *body_result = ir_gen_node(irb, node->data.while_expr.body, payload_scope); + if (body_result == irb->codegen->invalid_instruction) + return body_result; + irb->loop_stack.pop(); + + if (!instr_is_unreachable(body_result)) + ir_mark_gen(ir_build_br(irb, payload_scope, node, continue_block, is_comptime)); + + if (continue_expr_node) { + ir_set_cursor_at_end(irb, continue_block); + IrInstruction *expr_result = ir_gen_node(irb, continue_expr_node, payload_scope); + if (expr_result == irb->codegen->invalid_instruction) + return expr_result; + if (!instr_is_unreachable(expr_result)) + ir_mark_gen(ir_build_br(irb, payload_scope, node, cond_block, is_comptime)); + } + + if (else_node) { + ir_set_cursor_at_end(irb, else_block); + + // TODO make it an error to write to error variable + AstNode *err_symbol_node = else_node; // TODO make more accurate + VariableTableEntry *err_var = ir_create_var(irb, err_symbol_node, scope, err_symbol, + true, false, false, is_comptime); + Scope *err_scope = err_var->child_scope; + IrInstruction *err_var_value = ir_build_unwrap_err_code(irb, err_scope, err_symbol_node, err_val_ptr); + ir_build_var_decl(irb, err_scope, symbol_node, err_var, nullptr, err_var_value); + + IrInstruction *else_result = ir_gen_node(irb, else_node, err_scope); + if (else_result == irb->codegen->invalid_instruction) + return else_result; + if (!instr_is_unreachable(else_result)) + ir_mark_gen(ir_build_br(irb, scope, node, end_block, is_comptime)); + } + + ir_set_cursor_at_end(irb, end_block); + return ir_build_const_void(irb, scope, node); + } else if (var_symbol != nullptr) { + ir_set_cursor_at_end(irb, cond_block); + // TODO make it an error to write to payload variable + AstNode *symbol_node = node; // TODO make more accurate + VariableTableEntry *payload_var = ir_create_var(irb, symbol_node, scope, var_symbol, + true, false, false, is_comptime); + Scope *child_scope = payload_var->child_scope; + IrInstruction *maybe_val_ptr = ir_gen_node_extra(irb, node->data.while_expr.condition, scope, LVAL_PTR); + if (maybe_val_ptr == irb->codegen->invalid_instruction) + return maybe_val_ptr; + IrInstruction *maybe_val = ir_build_load_ptr(irb, scope, node->data.while_expr.condition, maybe_val_ptr); + IrInstruction *is_non_null = ir_build_test_nonnull(irb, scope, node->data.while_expr.condition, maybe_val); + if (!instr_is_unreachable(is_non_null)) { + ir_mark_gen(ir_build_cond_br(irb, scope, node->data.while_expr.condition, is_non_null, + body_block, else_block, is_comptime)); + } + + ir_set_cursor_at_end(irb, body_block); + IrInstruction *var_ptr_value = ir_build_unwrap_maybe(irb, child_scope, symbol_node, maybe_val_ptr, false); + IrInstruction *var_value = node->data.while_expr.var_is_ptr ? + var_ptr_value : ir_build_load_ptr(irb, child_scope, symbol_node, var_ptr_value); + ir_build_var_decl(irb, child_scope, symbol_node, payload_var, nullptr, var_value); + LoopStackItem *loop_stack_item = irb->loop_stack.add_one(); + loop_stack_item->break_block = end_block; + loop_stack_item->continue_block = continue_block; + loop_stack_item->is_comptime = is_comptime; + IrInstruction *body_result = ir_gen_node(irb, node->data.while_expr.body, child_scope); + if (body_result == irb->codegen->invalid_instruction) + return body_result; + irb->loop_stack.pop(); + + if (!instr_is_unreachable(body_result)) + ir_mark_gen(ir_build_br(irb, child_scope, node, continue_block, is_comptime)); + + if (continue_expr_node) { + ir_set_cursor_at_end(irb, continue_block); + IrInstruction *expr_result = ir_gen_node(irb, continue_expr_node, child_scope); + if (expr_result == irb->codegen->invalid_instruction) + return expr_result; + if (!instr_is_unreachable(expr_result)) + ir_mark_gen(ir_build_br(irb, child_scope, node, cond_block, is_comptime)); + } + + if (else_node) { + ir_set_cursor_at_end(irb, else_block); + + IrInstruction *else_result = ir_gen_node(irb, else_node, scope); + if (else_result == irb->codegen->invalid_instruction) + return else_result; + if (!instr_is_unreachable(else_result)) + ir_mark_gen(ir_build_br(irb, scope, node, end_block, is_comptime)); + } + + ir_set_cursor_at_end(irb, end_block); + return ir_build_const_void(irb, scope, node); + } else { + if (continue_expr_node) { + ir_set_cursor_at_end(irb, continue_block); + IrInstruction *expr_result = ir_gen_node(irb, continue_expr_node, scope); + if (expr_result == irb->codegen->invalid_instruction) + return expr_result; + if (!instr_is_unreachable(expr_result)) + ir_mark_gen(ir_build_br(irb, scope, node, cond_block, is_comptime)); + } + + ir_set_cursor_at_end(irb, cond_block); + IrInstruction *cond_val = ir_gen_node(irb, node->data.while_expr.condition, scope); + if (cond_val == irb->codegen->invalid_instruction) + return cond_val; + if (!instr_is_unreachable(cond_val)) { + ir_mark_gen(ir_build_cond_br(irb, scope, node->data.while_expr.condition, cond_val, + body_block, else_block, is_comptime)); + } + + ir_set_cursor_at_end(irb, body_block); + + LoopStackItem *loop_stack_item = irb->loop_stack.add_one(); + loop_stack_item->break_block = end_block; + loop_stack_item->continue_block = continue_block; + loop_stack_item->is_comptime = is_comptime; + IrInstruction *body_result = ir_gen_node(irb, node->data.while_expr.body, scope); + if (body_result == irb->codegen->invalid_instruction) + return body_result; + irb->loop_stack.pop(); + + if (!instr_is_unreachable(body_result)) + ir_mark_gen(ir_build_br(irb, scope, node, continue_block, is_comptime)); + + if (else_node) { + ir_set_cursor_at_end(irb, else_block); + + IrInstruction *else_result = ir_gen_node(irb, else_node, scope); + if (else_result == irb->codegen->invalid_instruction) + return else_result; + if (!instr_is_unreachable(else_result)) + ir_mark_gen(ir_build_br(irb, scope, node, end_block, is_comptime)); + } + + ir_set_cursor_at_end(irb, end_block); + + return ir_build_const_void(irb, scope, node); } - - ir_set_cursor_at_end(irb, cond_block); - IrInstruction *cond_val = ir_gen_node(irb, node->data.while_expr.condition, scope); - if (cond_val == irb->codegen->invalid_instruction) - return cond_val; - if (!instr_is_unreachable(cond_val)) { - ir_mark_gen(ir_build_cond_br(irb, scope, node->data.while_expr.condition, cond_val, - body_block, end_block, is_comptime)); - } - - ir_set_cursor_at_end(irb, body_block); - - LoopStackItem *loop_stack_item = irb->loop_stack.add_one(); - loop_stack_item->break_block = end_block; - loop_stack_item->continue_block = continue_block; - loop_stack_item->is_comptime = is_comptime; - IrInstruction *body_result = ir_gen_node(irb, node->data.while_expr.body, scope); - if (body_result == irb->codegen->invalid_instruction) - return body_result; - irb->loop_stack.pop(); - - if (!instr_is_unreachable(body_result)) - ir_mark_gen(ir_build_br(irb, scope, node, continue_block, is_comptime)); - ir_set_cursor_at_end(irb, end_block); - - return ir_build_const_void(irb, scope, node); } static IrInstruction *ir_gen_for_expr(IrBuilder *irb, Scope *parent_scope, AstNode *node) { diff --git a/src/parser.cpp b/src/parser.cpp index a8c26f3c0e..753dc6e67c 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -1654,7 +1654,7 @@ static AstNode *ast_parse_while_expr(ParseContext *pc, size_t *token_index, bool ast_eat_token(pc, token_index, TokenIdBinOr); } - node->data.while_expr.body = ast_parse_block_or_expression(pc, token_index, true); + node->data.while_expr.else_node = ast_parse_block_or_expression(pc, token_index, true); } return node; diff --git a/std/linked_list.zig b/std/linked_list.zig index 1f43669ad8..1b933d8058 100644 --- a/std/linked_list.zig +++ b/std/linked_list.zig @@ -187,43 +187,6 @@ pub fn LinkedList(comptime T: type) -> type { }; return node; } - - /// Iterate through the elements of the list. - /// - /// Returns: - /// A list iterator with a next() method. - pub fn iterate(list: &List) -> List.Iterator(false) { - List.Iterator(false) { - .node = list.first, - } - } - - /// Iterate through the elements of the list backwards. - /// - /// Returns: - /// A list iterator with a next() method. - pub fn iterateBackwards(list: &List) -> List.Iterator(true) { - List.Iterator(true) { - .node = list.last, - } - } - - /// Abstract iteration over a linked list. - pub fn Iterator(comptime backwards: bool) -> type { - struct { - const It = this; - - node: ?&Node, - - /// Return the next element of the list, until the end. - /// When no more elements are available, return null. - pub fn next(it: &It) -> ?&Node { - const current = it.node ?? return null; - it.node = if (backwards) current.prev else current.next; - return current; - } - } - } } } @@ -249,16 +212,24 @@ test "basic linked list test" { list.insertBefore(five, four); // {1, 2, 4, 5} list.insertAfter(two, three); // {1, 2, 3, 4, 5} - // Traverse the list forwards and backwards. - var it = list.iterate(); - var it_reverse = list.iterateBackwards(); - var index: u32 = 1; - while (true) { - const node = it.next() ?? break; - const node_reverse = it_reverse.next() ?? break; - assert (node.data == index); - assert (node_reverse.data == (6 - index)); - index += 1; + // traverse forwards + { + var it = list.first; + var index: u32 = 1; + while (it) |node| : (it = node.next) { + assert(node.data == index); + index += 1; + } + } + + // traverse backwards + { + var it = list.last; + var index: u32 = 1; + while (it) |node| : (it = node.prev) { + assert(node.data == (6 - index)); + index += 1; + } } var first = list.popFirst(); // {2, 3, 4, 5} diff --git a/test/cases/while.zig b/test/cases/while.zig index 70a47db0ae..476c29dd78 100644 --- a/test/cases/while.zig +++ b/test/cases/while.zig @@ -1,6 +1,6 @@ const assert = @import("std").debug.assert; -test "whileLoop" { +test "while loop" { var i : i32 = 0; while (i < 4) { i += 1; @@ -16,7 +16,7 @@ fn whileLoop2() -> i32 { return 1; } } -test "staticEvalWhile" { +test "static eval while" { assert(static_eval_while_number == 1); } const static_eval_while_number = staticWhileLoop1(); @@ -29,7 +29,7 @@ fn staticWhileLoop2() -> i32 { } } -test "continueAndBreak" { +test "continue and break" { runContinueAndBreakTest(); assert(continue_and_break_counter == 8); } @@ -47,7 +47,7 @@ fn runContinueAndBreakTest() { assert(i == 4); } -test "returnWithImplicitCastFromWhileLoop" { +test "return with implicit cast from while loop" { %%returnWithImplicitCastFromWhileLoopTest(); } fn returnWithImplicitCastFromWhileLoopTest() -> %void { @@ -56,7 +56,7 @@ fn returnWithImplicitCastFromWhileLoopTest() -> %void { } } -test "whileWithContinueExpr" { +test "while with continue expression" { var sum: i32 = 0; {var i: i32 = 0; while (i < 10) : (i += 1) { if (i == 5) continue; @@ -64,3 +64,73 @@ test "whileWithContinueExpr" { }} assert(sum == 40); } + +test "while with else" { + var sum: i32 = 0; + var i: i32 = 0; + var got_else: i32 = 0; + while (i < 10) : (i += 1) { + sum += 1; + } else { + got_else += 1; + } + assert(sum == 10); + assert(got_else == 1); +} + +test "while with nullable as condition" { + numbers_left = 10; + var sum: i32 = 0; + while (getNumberOrNull()) |value| { + sum += value; + } + assert(sum == 45); +} + +test "while with nullable as condition with else" { + numbers_left = 10; + var sum: i32 = 0; + var got_else: i32 = 0; + while (getNumberOrNull()) |value| { + sum += value; + assert(got_else == 0); + } else { + got_else += 1; + } + assert(sum == 45); + assert(got_else == 1); +} + +test "while with error union condition" { + numbers_left = 10; + var sum: i32 = 0; + var got_else: i32 = 0; + while (getNumberOrErr()) |value| { + sum += value; + } else |err| { + assert(err == error.OutOfNumbers); + got_else += 1; + } + assert(sum == 45); + assert(got_else == 1); +} + +var numbers_left: i32 = undefined; +error OutOfNumbers; +fn getNumberOrErr() -> %i32 { + return if (numbers_left == 0) { + error.OutOfNumbers + } else { + numbers_left -= 1; + numbers_left + }; +} +fn getNumberOrNull() -> ?i32 { + return if (numbers_left == 0) { + null + } else { + numbers_left -= 1; + numbers_left + }; +} + diff --git a/test/compile_errors.zig b/test/compile_errors.zig index 6bdbeebdc8..f6ff365451 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -1636,4 +1636,52 @@ pub fn addCases(cases: &tests.CompileErrorContext) { , ".tmp_source.zig:9:17: error: redefinition of 'Self'", ".tmp_source.zig:5:9: note: previous definition is here"); + + cases.add("while expected bool, got nullable", + \\export fn foo() { + \\ while (bar()) {} + \\} + \\fn bar() -> ?i32 { 1 } + , + ".tmp_source.zig:2:15: error: expected type 'bool', found '?i32'"); + + cases.add("while expected bool, got error union", + \\export fn foo() { + \\ while (bar()) {} + \\} + \\fn bar() -> %i32 { 1 } + , + ".tmp_source.zig:2:15: error: expected type 'bool', found '%i32'"); + + cases.add("while expected nullable, got bool", + \\export fn foo() { + \\ while (bar()) |x| {} + \\} + \\fn bar() -> bool { true } + , + ".tmp_source.zig:2:15: error: expected nullable type, found 'bool'"); + + cases.add("while expected nullable, got error union", + \\export fn foo() { + \\ while (bar()) |x| {} + \\} + \\fn bar() -> %i32 { 1 } + , + ".tmp_source.zig:2:15: error: expected nullable type, found '%i32'"); + + cases.add("while expected error union, got bool", + \\export fn foo() { + \\ while (bar()) |x| {} else |err| {} + \\} + \\fn bar() -> bool { true } + , + ".tmp_source.zig:2:15: error: expected error union type, found 'bool'"); + + cases.add("while expected error union, got nullable", + \\export fn foo() { + \\ while (bar()) |x| {} else |err| {} + \\} + \\fn bar() -> ?i32 { 1 } + , + ".tmp_source.zig:2:15: error: expected error union type, found '?i32'"); }