diff --git a/src/main.zig b/src/main.zig index bc4f209b45..be227a2895 100644 --- a/src/main.zig +++ b/src/main.zig @@ -2637,6 +2637,50 @@ fn argvCmd(allocator: *Allocator, argv: []const []const u8) ![]u8 { return cmd.toOwnedSlice(); } +fn readSourceFileToEndAlloc(allocator: *mem.Allocator, input: *const fs.File, size_hint: ?usize) ![]const u8 { + const source_code = input.readToEndAllocOptions( + allocator, + max_src_size, + size_hint, + @alignOf(u16), + null, + ) catch |err| switch (err) { + error.ConnectionResetByPeer => unreachable, + error.ConnectionTimedOut => unreachable, + error.NotOpenForReading => unreachable, + else => |e| return e, + }; + errdefer allocator.free(source_code); + + // Detect unsupported file types with their Byte Order Mark + const unsupported_boms = [_][]const u8{ + "\xff\xfe\x00\x00", // UTF-32 little endian + "\xfe\xff\x00\x00", // UTF-32 big endian + "\xfe\xff", // UTF-16 big endian + }; + for (unsupported_boms) |bom| { + if (mem.startsWith(u8, source_code, bom)) { + return error.UnsupportedEncoding; + } + } + + // If the file starts with a UTF-16 little endian BOM, translate it to UTF-8 + if (mem.startsWith(u8, source_code, "\xff\xfe")) { + const source_code_utf16_le = mem.bytesAsSlice(u16, source_code); + const source_code_utf8 = std.unicode.utf16leToUtf8Alloc(allocator, source_code_utf16_le) catch |err| switch (err) { + error.DanglingSurrogateHalf => error.UnsupportedEncoding, + error.ExpectedSecondSurrogateHalf => error.UnsupportedEncoding, + error.UnexpectedSecondSurrogateHalf => error.UnsupportedEncoding, + else => |e| return e, + }; + + allocator.free(source_code); + return source_code_utf8; + } + + return source_code; +} + pub const usage_fmt = \\Usage: zig fmt [file]... \\ @@ -2708,9 +2752,10 @@ pub fn cmdFmt(gpa: *Allocator, args: []const []const u8) !void { fatal("cannot use --stdin with positional arguments", .{}); } - const stdin = io.getStdIn().reader(); - - const source_code = try stdin.readAllAlloc(gpa, max_src_size); + const stdin = io.getStdIn(); + const source_code = readSourceFileToEndAlloc(gpa, &stdin, null) catch |err| { + fatal("unable to read stdin: {s}", .{err}); + }; defer gpa.free(source_code); var tree = std.zig.parse(gpa, source_code) catch |err| { @@ -2785,6 +2830,7 @@ const FmtError = error{ EndOfStream, Unseekable, NotOpenForWriting, + UnsupportedEncoding, } || fs.File.OpenError; fn fmtPath(fmt: *Fmt, file_path: []const u8, check_mode: bool, dir: fs.Dir, sub_path: []const u8) FmtError!void { @@ -2850,21 +2896,15 @@ fn fmtPathFile( if (stat.kind == .Directory) return error.IsDir; - const source_code = source_file.readToEndAllocOptions( + const source_code = try readSourceFileToEndAlloc( fmt.gpa, - max_src_size, + &source_file, std.math.cast(usize, stat.size) catch return error.FileTooBig, - @alignOf(u8), - null, - ) catch |err| switch (err) { - error.ConnectionResetByPeer => unreachable, - error.ConnectionTimedOut => unreachable, - error.NotOpenForReading => unreachable, - else => |e| return e, - }; + ); + defer fmt.gpa.free(source_code); + source_file.close(); file_closed = true; - defer fmt.gpa.free(source_code); // Add to set after no longer possible to get error.IsDir. if (try fmt.seen.fetchPut(stat.inode, {})) |_| return; diff --git a/test/cli.zig b/test/cli.zig index c0702fa54c..dedea67a59 100644 --- a/test/cli.zig +++ b/test/cli.zig @@ -28,6 +28,8 @@ pub fn main() !void { const zig_exe = try fs.path.resolve(a, &[_][]const u8{zig_exe_rel}); const dir_path = try fs.path.join(a, &[_][]const u8{ cache_root, "clitest" }); + defer fs.cwd().deleteTree(dir_path) catch {}; + const TestFn = fn ([]const u8, []const u8) anyerror!void; const test_fns = [_]TestFn{ testZigInitLib, @@ -174,4 +176,13 @@ fn testZigFmt(zig_exe: []const u8, dir_path: []const u8) !void { const run_result3 = try exec(dir_path, true, &[_][]const u8{ zig_exe, "fmt", dir_path }); // both files have been formatted, nothing should change now testing.expect(run_result3.stdout.len == 0); + + // Check UTF-16 decoding + const fmt4_zig_path = try fs.path.join(a, &[_][]const u8{ dir_path, "fmt4.zig" }); + var unformatted_code_utf16 = "\xff\xfe \x00 \x00 \x00 \x00/\x00/\x00 \x00n\x00o\x00 \x00r\x00e\x00a\x00s\x00o\x00n\x00"; + try fs.cwd().writeFile(fmt4_zig_path, unformatted_code_utf16); + + const run_result4 = try exec(dir_path, true, &[_][]const u8{ zig_exe, "fmt", dir_path }); + testing.expect(std.mem.startsWith(u8, run_result4.stdout, fmt4_zig_path)); + testing.expect(run_result4.stdout.len == fmt4_zig_path.len + 1 and run_result4.stdout[run_result4.stdout.len - 1] == '\n'); }