// index.js
import express from 'express';
import bodyParser from 'body-parser';
import Replicate from 'replicate';
import dotenv from 'dotenv';
import fs from 'fs';
import { UMAP } from 'umap-js';
dotenv.config();

// Load precomputed embeddings ({ text, embedding } pairs) from disk.
const embeddings = JSON.parse(fs.readFileSync('embeddings.json', 'utf-8'));

const replicate = new Replicate({ auth: process.env.REPLICATE_API_TOKEN });

// Replicate model IDs: LLaMA 2 7B Chat for generation, BGE-large for embeddings.
const chatModel = 'meta/llama-2-7b-chat';
const chatVersion = '8e6975e5ed6174911a6ff3d60540dfd4844201974602551e10e9e87ab143d81e';
const searchModel = 'nateraw/bge-large-en-v1.5';
const searchVersion = '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1';

const app = express();
app.use(bodyParser.json());
app.use(express.static('public'));
// Generate a reply with the LLaMA chat model, seeded by a Markov chain
// built from every message the bot has received so far.
async function generate(history) {
  // Format the conversation in LLaMA 2's [INST] ... [/INST] chat template,
  // appending "How are you?" to each user turn so the model always answers that question.
  let formattedHistory = '';
  for (let i = 0; i < history.length; i++) {
    if (history[i].role === 'user') {
      formattedHistory += `[INST] ${history[i].content} How are you? [/INST]\n`;
    } else {
      formattedHistory += `${history[i].content}\n`;
    }
  }

  // Append the latest user message to the corpus that feeds the Markov chain.
  fs.appendFile('data.txt', `\n${history[history.length - 1].content}`, function (err) {
    if (err) {
      console.error(err);
    }
  });
  console.log(`Generating response to: ${history[history.length - 1].content}`);

  if (formattedHistory.endsWith('\n')) {
    formattedHistory = formattedHistory.slice(0, -1);
  }

  // Build a word-level Markov chain from the accumulated corpus and sample from it.
  let raw = fs.readFileSync('data.txt', 'utf8');
  let lines = raw.split(/[\n\r]+/);
  let markov = new MarkovGeneratorWord(1, 280);
  for (let i = 0; i < lines.length; i++) {
    markov.feed(lines[i]);
  }
  let result = markov.generateMarkov();
  result = result.replace(/\n/g, '<br/><br/>');
  console.log(`Generated Markov: ${result}`);

  // Ask LLaMA to turn the raw Markov text into a grammatical answer to "how are you?".
  const input = {
    prompt: formattedHistory,
    temperature: 0.1,
    // system_prompt: `Someone asked you, “how are you?” Respond as if you feel like this: "${result}". But no matter what, do not ask any questions in your response.`,
    system_prompt: `Make this, "${result}", comprehensible with proper grammar and respond to "how are you?". No matter what, do not ask any questions in your response. Reply only with the formatted answer.`,
  };
  console.log(input);
  const output = await replicate.run(`${chatModel}:${chatVersion}`, { input });
  console.log(`Reformatting with LLaMA: ${output.join('').trim()}`);

  // Return both the reformatted reply and the raw Markov text.
  const response = [output.join('').trim(), result];
  return response;
}
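// Example (sketch): generate() expects an array of { role, content } turns and
// resolves to a two-element array: [LLaMA's reformatted reply, the raw Markov text].
// const reply = await generate([{ role: 'user', content: 'Hello there' }]);
// console.log(reply[0]); // grammatical answer to "how are you?"
// console.log(reply[1]); // raw Markov-chain output it was built from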
// Get an embedding vector for a given text from the BGE model on Replicate.
async function getEmbedding(text) {
  console.log(`Generating embedding for: "${text}"`);
  const input = {
    texts: JSON.stringify([text]),
    batch_size: 32,
    convert_to_numpy: false,
    normalize_embeddings: true,
  };
  const output = await replicate.run(`${searchModel}:${searchVersion}`, { input });
  return output[0];
}
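// Note (assumption about the Replicate model's output shape): the BGE model returns
// one embedding per input text, so output[0] is a single numeric vector, e.g.
// const vec = await getEmbedding('how are you?'); // => [0.012, -0.034, ...]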
// Rank the stored embeddings by cosine similarity to the prompt.
async function findSimilar(prompt) {
  const inputEmbedding = await getEmbedding(prompt);

  // Score every stored embedding against the prompt embedding.
  let similarities = embeddings.map(({ text, embedding }) => ({
    text,
    similarity: cosineSimilarity(inputEmbedding, embedding),
  }));

  // Sort by similarity, highest first, and log the top three matches.
  similarities = similarities.sort((a, b) => b.similarity - a.similarity);
  console.log('Similarities found:');
  similarities.slice(0, 3).forEach((s, i) => {
    console.log(`${i + 1}: ${s.text}, Score: ${s.similarity.toFixed(3)}`);
  });
  return similarities;
}
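// Example (sketch): findSimilar() resolves to the full list of stored texts, sorted
// by similarity, e.g. [{ text: '...', similarity: 0.87 }, { text: '...', similarity: 0.61 }, ...].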
// Project all stored embeddings to 2D with UMAP for visualization.
async function clustering() {
  let embeddingArr = [];
  for (let i = 0; i < embeddings.length; i++) {
    embeddingArr.push(embeddings[i].embedding);
  }
  let umap = new UMAP({ nNeighbors: 15, minDist: 0.1, nComponents: 2 });
  let umapResults = umap.fit(embeddingArr);
  return umapResults;
}
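// Note: with nComponents: 2, umap.fit() returns one [x, y] coordinate pair per stored
// embedding, in the same order as the embeddings array, e.g. [[1.2, -0.4], [0.3, 2.1], ...].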
// Endpoint to converse with LLaMA.
app.post('/api/chat', async (req, res) => {
  const conversationHistory = req.body.history;
  try {
    const modelReply = await generate(conversationHistory);
    res.json({ reply: modelReply });
  } catch (error) {
    console.error('[/api/chat] Error communicating with Replicate API:', error);
    res.status(500).send('Error generating response');
  }
});
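// Example request (sketch, assuming the server runs on the default port 3003):
// curl -X POST http://localhost:3003/api/chat \
//   -H 'Content-Type: application/json' \
//   -d '{"history":[{"role":"user","content":"Hi"}]}'
// Response: { "reply": ["<LLaMA reply>", "<raw Markov text>"] }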
// Endpoint to find similar texts based on embeddings.
app.post('/api/similar', async (request, response) => {
  let prompt = request.body.prompt;
  console.log('Searching for similar responses to: ' + prompt);
  let n = request.body.n || 10;
  try {
    let similarities = await findSimilar(prompt);
    similarities = similarities.slice(0, n);
    response.json(similarities);
  } catch (error) {
    console.error('[/api/similar] Error communicating with Replicate API:', error);
    response.status(500).send('Error generating response');
  }
});
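// Example request (sketch): n is optional and defaults to 10.
// curl -X POST http://localhost:3003/api/similar \
//   -H 'Content-Type: application/json' \
//   -d '{"prompt":"feeling tired today","n":3}'
// Response: [{ "text": "...", "similarity": 0.83 }, ...]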
// Endpoint to return UMAP-projected embeddings for clustering/visualization.
app.post('/api/cluster', async (request, response) => {
  try {
    let umapResults = await clustering();
    let umapEmbeds = { umapResults, embeddings };
    response.json(umapEmbeds);
  } catch (error) {
    console.error('[/api/cluster] Error generating clusters:', error);
    response.status(500).send('Error generating clusters');
  }
});
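// Example response shape (sketch):
// { "umapResults": [[x, y], ...], "embeddings": [{ "text": "...", "embedding": [...] }, ...] },
// where umapResults[i] is the 2D projection of embeddings[i].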
const PORT = process.env.PORT || 3003;
app.listen(PORT, () => {
  console.log(`Server is running on http://localhost:${PORT}`);
});
// Split a string into word tokens on whitespace.
String.prototype.tokenize = function () {
  return this.split(/\s+/);
};

// Word-level Markov chain generator: n is the order (words per n-gram),
// max is the maximum number of words to generate.
class MarkovGeneratorWord {
  constructor(n, max) {
    this.n = n;
    this.max = max;
    this.ngrams = {};     // maps each n-gram to the words that can follow it
    this.beginnings = []; // n-grams that start a fed line
  }

  // Add one line of text to the model.
  feed(text) {
    let tokens = text.tokenize();
    if (tokens.length < this.n) {
      return false;
    }
    let beginning = tokens.slice(0, this.n).join(' ');
    this.beginnings.push(beginning);
    for (let i = 0; i < tokens.length - this.n; i++) {
      let gram = tokens.slice(i, i + this.n).join(' ');
      let next = tokens[i + this.n];
      if (!this.ngrams[gram]) {
        this.ngrams[gram] = [];
      }
      this.ngrams[gram].push(next);
    }
  }

  // Walk the chain from a random beginning until max words or a dead end.
  generateMarkov() {
    if (this.beginnings.length === 0) {
      return '';
    }
    let current = this.beginnings[Math.floor(Math.random() * this.beginnings.length)];
    let output = current.tokenize();
    for (let i = 0; i < this.max; i++) {
      if (this.ngrams[current]) {
        let possible_next = this.ngrams[current];
        let next = possible_next[Math.floor(Math.random() * possible_next.length)];
        output.push(next);
        current = output.slice(output.length - this.n, output.length).join(' ');
      } else {
        break;
      }
    }
    return output.join(' ');
  }
}
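// Example usage (sketch, separate from the chain built in generate()):
// const m = new MarkovGeneratorWord(1, 20);
// m.feed('i am feeling fine today');
// m.feed('i am so tired');
// console.log(m.generateMarkov()); // e.g. "i am feeling fine today" or "i am so tired"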
// Cosine similarity between two vectors: dot(A, B) / (|A| * |B|).
function dotProduct(vecA, vecB) {
  return vecA.reduce((sum, val, i) => sum + val * vecB[i], 0);
}

function magnitude(vec) {
  return Math.sqrt(vec.reduce((sum, val) => sum + val * val, 0));
}

function cosineSimilarity(vecA, vecB) {
  return dotProduct(vecA, vecB) / (magnitude(vecA) * magnitude(vecB));
}
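// Worked example: cosineSimilarity([1, 1, 0], [1, 0, 0])
//   dot = 1, |A| = sqrt(2), |B| = 1, so similarity = 1 / sqrt(2) ≈ 0.707.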