From 2f72f47191543e2b113ed8c0918c232135f813a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:40:23 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20Fix=20chat=20for=20windows=20(#2?= =?UTF-8?q?443)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix chat for windows * add some tests back * Revert "add some tests back" This reverts commit 350aef52f53f8cf34fccd7ad0f78a3dd63867e06. --- examples/scripts/chat.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/scripts/chat.py b/examples/scripts/chat.py index 34f4ecd0ce..12e7c448d4 100644 --- a/examples/scripts/chat.py +++ b/examples/scripts/chat.py @@ -15,7 +15,7 @@ import copy import json import os -import pwd +import platform import re import sys import time @@ -32,6 +32,10 @@ from trl.trainer.utils import get_quantization_config +if platform.system() != "Windows": + import pwd + + init_zero_verbose() HELP_STRING = """\ @@ -138,7 +142,10 @@ def print_help(self): def get_username(): - return pwd.getpwuid(os.getuid())[0] + if platform.system() == "Windows": + return os.getlogin() + else: + return pwd.getpwuid(os.getuid()).pw_name def create_default_filename(model_name):