diff --git a/apps/cli/src/git.zig b/apps/cli/src/git.zig index 5999590..eafea59 100644 --- a/apps/cli/src/git.zig +++ b/apps/cli/src/git.zig @@ -7,6 +7,11 @@ pub fn getStagedChanges(allocator: std.mem.Allocator) ![]u8 { "--cached", } }); + if (result.stdout.len == 0) { + try std.io.getStdOut().writer().print("No changes to commit\n", .{}); + std.process.exit(0); + } + return result.stdout; } diff --git a/apps/cli/src/main.zig b/apps/cli/src/main.zig index 4ecc3a6..287fdd0 100644 --- a/apps/cli/src/main.zig +++ b/apps/cli/src/main.zig @@ -8,7 +8,7 @@ var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator); const allocator = arena.allocator(); var config = struct { - model: []const u8 = "gpt-4-turbo-preview", + model: []const u8 = openai.Model.Default.toString(), }{}; var model = cli.Option{ @@ -19,11 +19,6 @@ var model = cli.Option{ fn generate_handler() !void { const diff = try git.getStagedChanges(allocator); - if (diff.len == 0) { - try std.io.getStdOut().writer().print("No changes to commit\n", .{}); - return; - } - const response = try openai.getCompletion(allocator, diff, config.model); try std.io.getStdOut().writer().print("{s}\n", .{response.choices[0].message.content}); @@ -31,11 +26,6 @@ fn generate_handler() !void { fn commit_handler() !void { const diff = try git.getStagedChanges(allocator); - if (diff.len == 0) { - try std.io.getStdOut().writer().print("No changes to commit\n", .{}); - return; - } - const response = try openai.getCompletion(allocator, diff, config.model); const message = response.choices[0].message.content; diff --git a/apps/cli/src/openai.zig b/apps/cli/src/openai.zig index b2371d8..d56d8c3 100644 --- a/apps/cli/src/openai.zig +++ b/apps/cli/src/openai.zig @@ -20,6 +20,24 @@ const CompletionResponse = struct { choices: []Choice, }; +pub const Model = enum(u2) { + Fast, + Default, + + /// Returns the string representation of the enum value + pub fn toString(self: Model) []const u8 { + return switch (self) { + .Fast => "gpt-3.5-turbo", + .Default => "gpt-4-turbo-preview", + }; + } + + /// Checks whether the provided model is supported + pub fn isSupported(model: []const u8) bool { + return std.mem.eql(u8, model, Model.Fast.toString()) or std.mem.eql(u8, model, Model.Default.toString()); + } +}; + const SYSTEM_MESSAGE = \\ You are a helpful coding assistant responsible for generating fitting commit messages. \\ You will be provided a git diff or code snippet and you are expected to provide a suitable commit message. @@ -48,6 +66,11 @@ const SYSTEM_MESSAGE = ; pub fn getCompletion(allocator: std.mem.Allocator, prompt: []const u8, model: []const u8) !CompletionResponse { + if (!Model.isSupported(model)) { + std.debug.print("Unsupported model provided. If you believe the model should be supported, report an issue at https://github.com/segersniels/genmoji/issues.\n", .{}); + std.process.exit(1); + } + var env = try std.process.getEnvMap(allocator); defer env.deinit();