Skip to content

Commit

Permalink
fix: [autoComplete] optimized complete logic
Browse files Browse the repository at this point in the history
Log: as title
  • Loading branch information
LiHua000 authored and deepin-mozart committed Sep 20, 2024
1 parent 859f569 commit a2aaaba
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 26 deletions.
30 changes: 18 additions & 12 deletions src/plugins/codegeex/codegeex/copilotapi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ CopilotApi::CopilotApi(QObject *parent)
{
}

void CopilotApi::postGenerate(const QString &url, const QString &code, const QString &suffix)
void CopilotApi::postGenerate(const QString &url, const QString &prefix, const QString &suffix, GenerateType type)
{
if (completionReply)
completionReply->close();

QByteArray body = assembleGenerateBody(code, suffix);
QByteArray body = assembleGenerateBody(prefix, suffix, type);
QNetworkReply *reply = postMessage(url, CodeGeeXManager::instance()->getSessionId(), body);
completionReply = reply;
reply->setProperty("responseType", CopilotApi::inline_completions);
Expand Down Expand Up @@ -99,7 +98,7 @@ QNetworkReply *CopilotApi::postMessage(const QString &url,
"lang":
}
*/
QByteArray CopilotApi::assembleGenerateBody(const QString &prefix, const QString &suffix)
QByteArray CopilotApi::assembleGenerateBody(const QString &prefix, const QString &suffix, GenerateType type)
{
auto file = getCurrentFileInfo();

Expand All @@ -120,7 +119,10 @@ QByteArray CopilotApi::assembleGenerateBody(const QString &prefix, const QString
json.insert("context", context);
json.insert("model", completionModel);
json.insert("lang", file.second);
json.insert("max_new_tokens", 128);
if (type == GenerateType::Line)
json.insert("max_new_tokens", 64);
else
json.insert("max_new_tokens", 128);

QJsonDocument doc(json);
return doc.toJson();
Expand Down Expand Up @@ -189,15 +191,19 @@ void CopilotApi::slotReadReply(QNetworkReply *reply)
if (type == CopilotApi::inline_completions) {
auto content = jsonObject.value("inline_completions").toArray().at(0).toObject();
code = content.value("text").toString();
// Cut the first code segment
auto codeLines = code.split('\n');
QString lastLine = codeLines.last();
if (content.value("finish_reason").toString() == "length") {
// Due to the length limit of the code, the last line will be discarded when the code is truncated.
auto codeLines = code.split('\n');
if (codeLines.size() > 1)
codeLines.removeLast();
code = codeLines.join('\n');
}

QRegularExpression endOfLinePattern("\\n|;|}");
if (!endOfLinePattern.match(lastLine).hasMatch())
codeLines.removeLast();
code = codeLines.mid(0, codeLines.indexOf("", 1)).join('\n') + '\n';
completionReply = nullptr;

// all '\n'
if (code.split('\n', QString::SkipEmptyParts).isEmpty())
return;
emit response(CopilotApi::inline_completions, code, "");
} else if (type == CopilotApi::multilingual_code_translate) {
auto codeLines = jsonObject.value("text").toString().split('\n');
Expand Down
10 changes: 8 additions & 2 deletions src/plugins/codegeex/codegeex/copilotapi.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,15 @@ class CopilotApi : public QObject
Q_OBJECT

public:
enum GenerateType {
Line,
Block
};

CopilotApi(QObject *parent = nullptr);
void setModel(languageModel model);

void postGenerate(const QString &url, const QString &code, const QString &suffix);
void postGenerate(const QString &url, const QString &prefix, const QString &suffix, GenerateType type);

void postComment(const QString &url,
const QString &code,
Expand Down Expand Up @@ -157,7 +162,8 @@ public slots:
QNetworkReply *postMessage(const QString &url, const QString &token, const QByteArray &body);

QByteArray assembleGenerateBody(const QString &prefix,
const QString &suffix);
const QString &suffix,
GenerateType type);

QByteArray assembleTranslateBody(const QString &code,
const QString &dst_lang,
Expand Down
88 changes: 78 additions & 10 deletions src/plugins/codegeex/copilot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ static const char *commandCommits = "commit_message";

using namespace CodeGeeX;
using namespace dpfservice;

Copilot::Copilot(QObject *parent)
: QObject(parent)
{
Expand All @@ -36,12 +37,24 @@ Copilot::Copilot(QObject *parent)
replaceSelectedText(response);
break;
case CopilotApi::inline_completions:
mutexResponse.lock();
generateResponse = response;
if (editorService->setCompletion && responseValid(response)) {
editorService->setCompletion(generateResponse, QIcon::fromTheme("codegeex_anwser_icon"), QKeySequence(Qt::CTRL | Qt::Key_T));
if (!responseValid(response))
return;
{
QString completion = "";

if (generateType == CopilotApi::Line) {
generateCache = response.split('\n');
completion = extractSingleLine();
} else if (generateType == CopilotApi::Block) {
generateCache.clear();
completion = response;
}

if (editorService->setCompletion) {
editorService->setCompletion(completion, QIcon::fromTheme("codegeex_anwser_icon"), QKeySequence(Qt::CTRL | Qt::Key_T));
generatedCode = completion;
}
}
mutexResponse.unlock();
break;
case CopilotApi::multilingual_code_translate:
emit translatedResult(response, dstLang);
Expand Down Expand Up @@ -170,12 +183,21 @@ void Copilot::generateCode()
if (!generateCodeEnabled)
return;

QString prompt = editorService->getCursorBeforeText();
QString prefix = editorService->getCursorBeforeText();
QString suffix = editorService->getCursorBehindText();

copilotApi.postGenerate(kUrlGenerateMultiLine,
prompt,
suffix);
if (!prefix.endsWith(generatedCode) || generateCache.isEmpty()) {
generateType = checkPrefixType(prefix);
copilotApi.postGenerate(kUrlGenerateMultiLine,
prefix,
suffix,
generateType);
} else {
QString completion = extractSingleLine();
if (editorService->setCompletion) {
editorService->setCompletion(completion, QIcon::fromTheme("codegeex_anwser_icon"), QKeySequence(Qt::CTRL | Qt::Key_T));
generatedCode = completion;
}
}
}

void Copilot::login()
Expand Down Expand Up @@ -265,3 +287,49 @@ QString Copilot::assembleCodeByCurrentFile(const QString &code)
result = "```" + fileType + "\n" + code + "```";
return result;
}

CodeGeeX::CopilotApi::GenerateType Copilot::checkPrefixType(const QString &prefixCode)
{
//todo
Q_UNUSED(prefixCode)
if (0)
return CopilotApi::Line;
else
return CopilotApi::Block;
}

QString Copilot::extractSingleLine()
{
if (generateCache.isEmpty())
return "";

bool extractedCode = false;
QString completion = "";
for (auto line : generateCache) {
if (extractedCode)
break;
if (line != "")
extractedCode = true;

completion += line == "" ? "\n" : line;
generateCache.removeFirst();
}
completion += "\n";

//check if left cache all '\n'
bool leftAllEmpty = true;
for (auto line : generateCache) {
if (line == "")
continue;
leftAllEmpty = false;
break;
}
if (leftAllEmpty) {
generateCache.clear();
completion += "\n";
}

if (!extractedCode)
completion = "";
return completion;
}
9 changes: 7 additions & 2 deletions src/plugins/codegeex/copilot.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,13 @@ public slots:

CodeGeeX::CopilotApi copilotApi;
dpfservice::EditorService *editorService = nullptr;
QString generateResponse;
QMutex mutexResponse;
QStringList generateCache {};
QString generatedCode {};
QString extractSingleLine();

CodeGeeX::CopilotApi::GenerateType generateType;
CodeGeeX::CopilotApi::GenerateType checkPrefixType(const QString &prefixCode);

bool generateCodeEnabled = true;
};

Expand Down

0 comments on commit a2aaaba

Please sign in to comment.