diff --git a/std/zig/ast.zig b/std/zig/ast.zig index 3466a24d28..3892812882 100644 --- a/std/zig/ast.zig +++ b/std/zig/ast.zig @@ -102,7 +102,7 @@ pub const NodeFnProto = struct { fn_token: Token, name_token: ?Token, params: ArrayList(&Node), - return_type: &Node, + return_type: ReturnType, var_args_token: ?Token, extern_token: ?Token, inline_token: ?Token, @@ -111,6 +111,12 @@ pub const NodeFnProto = struct { lib_name: ?&Node, // populated if this is an extern declaration align_expr: ?&Node, // populated if align(A) is present + pub const ReturnType = union(enum) { + Explicit: &Node, + Infer, + InferErrorSet: &Node, + }; + pub fn iterate(self: &NodeFnProto, index: usize) ?&Node { var i = index; @@ -119,8 +125,18 @@ pub const NodeFnProto = struct { i -= 1; } - if (i < 1) return self.return_type; - i -= 1; + switch (self.return_type) { + // TODO allow this and next prong to share bodies since the types are the same + ReturnType.Explicit => |node| { + if (i < 1) return node; + i -= 1; + }, + ReturnType.InferErrorSet => |node| { + if (i < 1) return node; + i -= 1; + }, + ReturnType.Infer => {}, + } if (self.align_expr) |align_expr| { if (i < 1) return align_expr; diff --git a/std/zig/parser.zig b/std/zig/parser.zig index 079a331c6d..0c24fe1410 100644 --- a/std/zig/parser.zig +++ b/std/zig/parser.zig @@ -87,6 +87,7 @@ pub const Parser = struct { ExpectToken: @TagType(Token.Id), FnProto: &ast.NodeFnProto, FnProtoAlign: &ast.NodeFnProto, + FnProtoReturnType: &ast.NodeFnProto, ParamDecl: &ast.NodeFnProto, ParamDeclComma, FnDef: &ast.NodeFnProto, @@ -178,7 +179,7 @@ pub const Parser = struct { stack.append(State.TopLevel) catch unreachable; // TODO shouldn't need these casts const fn_proto = try self.createAttachFnProto(arena, &root_node.decls, token, - ctx.extern_token, (?Token)(null), (?Token)(null), (?Token)(null)); + ctx.extern_token, (?Token)(null), ctx.visib_token, (?Token)(null)); try stack.append(State { .FnDef = fn_proto }); try stack.append(State { .FnProto = fn_proto }); continue; @@ -466,11 +467,37 @@ pub const Parser = struct { } self.putBackToken(token); stack.append(State { - .TypeExpr = DestPtr {.Field = &fn_proto.return_type}, + .FnProtoReturnType = fn_proto, }) catch unreachable; continue; }, + State.FnProtoReturnType => |fn_proto| { + const token = self.getNextToken(); + switch (token.id) { + Token.Id.Keyword_var => { + fn_proto.return_type = ast.NodeFnProto.ReturnType.Infer; + }, + Token.Id.Bang => { + fn_proto.return_type = ast.NodeFnProto.ReturnType { .InferErrorSet = undefined }; + stack.append(State { + .TypeExpr = DestPtr {.Field = &fn_proto.return_type.InferErrorSet}, + }) catch unreachable; + }, + else => { + self.putBackToken(token); + fn_proto.return_type = ast.NodeFnProto.ReturnType { .Explicit = undefined }; + stack.append(State { + .TypeExpr = DestPtr {.Field = &fn_proto.return_type.Explicit}, + }) catch unreachable; + }, + } + if (token.id == Token.Id.Keyword_align) { + @panic("TODO fn proto align"); + } + continue; + }, + State.ParamDecl => |fn_proto| { var token = self.getNextToken(); if (token.id == Token.Id.RParen) { @@ -977,19 +1004,23 @@ pub const Parser = struct { }, ast.Node.Id.Block => { const block = @fieldParentPtr(ast.NodeBlock, "base", base); - try stream.write("{"); - try stack.append(RenderState { .Text = "}"}); - try stack.append(RenderState.PrintIndent); - try stack.append(RenderState { .Indent = indent}); - try stack.append(RenderState { .Text = "\n"}); - var i = block.statements.len; - while (i != 0) { - i -= 1; - const statement_node = block.statements.items[i]; - try stack.append(RenderState { .Statement = statement_node}); + if (block.statements.len == 0) { + try stream.write("{}"); + } else { + try stream.write("{"); + try stack.append(RenderState { .Text = "}"}); try stack.append(RenderState.PrintIndent); - try stack.append(RenderState { .Indent = indent + indent_delta}); - try stack.append(RenderState { .Text = "\n" }); + try stack.append(RenderState { .Indent = indent}); + try stack.append(RenderState { .Text = "\n"}); + var i = block.statements.len; + while (i != 0) { + i -= 1; + const statement_node = block.statements.items[i]; + try stack.append(RenderState { .Statement = statement_node}); + try stack.append(RenderState.PrintIndent); + try stack.append(RenderState { .Indent = indent + indent_delta}); + try stack.append(RenderState { .Text = "\n" }); + } } }, ast.Node.Id.InfixOp => { @@ -1071,7 +1102,18 @@ pub const Parser = struct { try stack.append(RenderState { .Expression = body_node}); try stack.append(RenderState { .Text = " "}); } - try stack.append(RenderState { .Expression = fn_proto.return_type}); + switch (fn_proto.return_type) { + ast.NodeFnProto.ReturnType.Explicit => |node| { + try stack.append(RenderState { .Expression = node}); + }, + ast.NodeFnProto.ReturnType.Infer => { + try stream.print("var"); + }, + ast.NodeFnProto.ReturnType.InferErrorSet => |node| { + try stream.print("!"); + try stack.append(RenderState { .Expression = node}); + }, + } }, RenderState.Statement => |base| { switch (base.id) { @@ -1169,6 +1211,13 @@ fn testCanonical(source: []const u8) !void { } test "zig fmt" { + try testCanonical( + \\pub fn main() !void {} + \\pub fn main() var {} + \\pub fn main() i32 {} + \\ + ); + try testCanonical( \\const std = @import("std"); \\const std = @import();