-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain.py
396 lines (327 loc) · 12.2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
# Standard library imports
import json
import os
import pickle
import re
import threading
import time
import traceback
import itertools
from collections import defaultdict
# Third-party imports
from mido import MidiFile
# Local application/library specific imports
import fluidsynth_lib
from client import AbletonOSCClient
from generate_midi import make_midi, modify_midi, midifile_to_notes
# Constants
IP = "127.0.0.1"  # address of the Ableton OSC endpoint
PORT = 11000  # port passed to AbletonOSCClient
OPERATOR_MAP_PATH = "./params/OperatorMap.json"
MIDI_MAP_PATH = "./params/MidiMap.json"

# Load Operator patch map (mapping generated by ChatGPT):
# int instrument id -> patch filename under ./params/Operator/.
with open(OPERATOR_MAP_PATH, "r", encoding="utf-8") as f:
    # JSON object keys are always strings; convert them back to ints for lookup.
    operator_map = {int(k): v for k, v in json.load(f).items()}

# Load General MIDI map: int instrument id -> name used when renaming tracks.
with open(MIDI_MAP_PATH, "r", encoding="utf-8") as f:
    midi_map = {int(k): v for k, v in json.load(f).items()}

# Initialize Ableton OSC client
client = AbletonOSCClient(IP, PORT)


def _ableton_responds():
    """Return True if the remote script answers ``/live/test`` with ``"ok"``.

    Any exception raised by the query, or a reply that cannot be indexed,
    counts as "no response". ``Exception`` (not a bare ``except``) is caught
    so that Ctrl+C still aborts start-up, and an explicit comparison is used
    instead of ``assert`` so the check still runs under ``python -O``.
    """
    try:
        return client.query("/live/test")[0] == "ok"
    except Exception:
        return False


if not _ableton_responds():
    print(
        "No response from Ableton. You need to copy client.py into Remote Scripts, "
        "and enable it as a control surface in preferences."
    )
class Clip:
    """Static helpers that read and modify Ableton Live clips over OSC.

    Helpers are addressed by ``(x, y)``. ``x`` is the track index; it is also
    sent on its own to the ``/live/track/*`` endpoints. ``y`` is the clip slot
    on that track; it is sent together with ``x`` to
    ``/live/clip_slot/create_clip``. All methods are ``@staticmethod`` and are
    called as ``Clip.method(...)``.
    """

    @staticmethod
    def get_name(x, y):
        """Return the name of the clip at ``(x, y)``.

        The reply to ``/live/clip/get/name`` carries the coordinates before the
        name, so only the third element is returned. Exceptions raised by
        ``client.query`` propagate; ``get_curr_name`` relies on ``RuntimeError``.

        Args:
            x (int): Track index.
            y (int): Clip slot index.

        Returns:
            str: The clip name.
        """
        _x, _y, name = client.query("/live/clip/get/name", (x, y))
        return name

    @staticmethod
    def get_curr_name():
        """Return ``(x, y, name)`` for the clip currently selected in Live.

        Returns:
            tuple: ``(x, y, name)``, or ``(None, None, None)`` when the
            selection query or the name query raises ``RuntimeError``
            (no selected clip / no name).
        """
        try:
            x, y = client.query("/live/view/get/selected_clip")
        except RuntimeError:
            return None, None, None
        try:
            return x, y, Clip.get_name(x, y)
        except RuntimeError:
            return None, None, None

    @staticmethod
    def set_name(x, y, name):
        """Rename the clip at ``(x, y)`` to ``name``.

        Args:
            x (int): Track index.
            y (int): Clip slot index.
            name (str): New clip name.
        """
        client.send_message("/live/clip/set/name", (x, y, name))

    @staticmethod
    def create(x, y, length):
        """Create an empty clip of ``length`` in clip slot ``(x, y)``.

        Args:
            x (int): Track index.
            y (int): Clip slot index.
            length: Clip length, as returned by ``midifile_to_notes``.
        """
        client.send_message("/live/clip_slot/create_clip", (x, y, length))

    @staticmethod
    def remove_notes(x, y):
        """Remove the notes from the clip at ``(x, y)``.

        Args:
            x (int): Track index.
            y (int): Clip slot index.
        """
        # NOTE(review): the trailing numbers presumably select a pitch range
        # (0, 127) and a time range (0, 100000) wide enough to cover the whole
        # clip. Confirm against the remote script's handler.
        client.send_message("/live/clip/remove/notes", (x, y, 0, 127, 0, 100000))

    @staticmethod
    def get_notes(x, y):
        """Return the notes in the clip at ``(x, y)``.

        Args:
            x (int): Track index.
            y (int): Clip slot index.

        Returns:
            The reply with its first two elements dropped (presumably the
            echoed ``x, y``), or ``None`` if the query raised.
        """
        try:
            notes = client.query("/live/clip/get/notes", (x, y, 0, 127, 0, 100000))[2:]
            print("notes:", notes)
            return notes
        except Exception as e:
            print("failed to get notes", e)
            return None

    @staticmethod
    def set_loop_points(x, y, start, end):
        """Set the loop start and loop end of the clip at ``(x, y)``.

        Args:
            x (int): Track index.
            y (int): Clip slot index.
            start: Loop start position.
            end: Loop end position.
        """
        client.send_message("/live/clip/set/loop_start", (x, y, start))
        client.send_message("/live/clip/set/loop_end", (x, y, end))

    @staticmethod
    def insert_clip(x, y, midifilename, prompt):
        """Fill clip slot ``(x, y)`` with the notes of a MIDI file.

        Steps:
        1. Clear the existing notes.
        2. (Re)create the clip with the file's length.
        3. Name it ``"AI:" + prompt``, suffixed with ``"empty"`` when the file
           has no notes.
        4. Add every note and loop over the whole clip.
        5. Apply the file's instrument. If the track's first device is
           Operator, load the mapped Operator patch. Otherwise, set the GM
           program on the track's output MIDI channel via ``fluidsynth_lib``
           and rename the track after the GM instrument.

        Args:
            x (int): Track index.
            y (int): Clip slot index.
            midifilename (str): Path of the MIDI file to insert.
            prompt (str): Prompt text used in the clip name.
        """
        try:
            # Class name of device 0 on the track (e.g. "Operator").
            class_name = client.query("/live/device/get/class_name", (x, 0))[2]
        except Exception:
            class_name = None
        Clip.remove_notes(x, y)
        midifile = MidiFile(midifilename)
        full_notes, length, instrument = midifile_to_notes(midifile)
        Clip.create(x, y, length)
        Clip.set_name(x, y, "AI:" + prompt + ("empty" if len(full_notes) == 0 else ""))
        for note, velocity, start, duration in full_notes:
            # Argument order: (x, y, pitch, start, duration, velocity, mute).
            client.send_message(
                "/live/clip/add/notes",
                (x, y, note, start, duration, velocity, 0),
            )
        Clip.set_loop_points(x, y, 0, length)
        if instrument is None:
            print("no instrument found")
            return
        if class_name == "Operator":
            print("Setting Operator instrument to", instrument)
            Clip.set_instrument(x, y, instrument)
        else:
            channel = Clip.get_output_channel(x)
            if channel is not None:
                fluidsynth_lib.set_instrument(channel, instrument)
            try:
                name = midi_map[instrument]
                client.send_message("/live/track/set/name", (x, name))
            except Exception as e:
                print("failed to set instrument name", e)
                traceback.print_exc()

    @staticmethod
    def get_output_channel(x):
        """Return the 0-based output MIDI channel of track ``x``.

        Args:
            x (int): Track index.

        Returns:
            int | None: The channel, or ``None`` if it could not be determined
            (the error is printed).
        """
        try:
            chan = client.query("/live/track/get/output_routing_channel", (x,))[1]
            print(chan)
            # Keep the last space-separated token (presumably the 1-based
            # channel number, e.g. "Ch. 3") and convert it to 0-based.
            chan = chan.rsplit(" ", 1)[-1]
            return int(chan) - 1
        except Exception as e:
            print("Failed to set GM-MIDI instrument. Couldn't get output channel:", e)
            return None

    @staticmethod
    def set_instrument(x, y, instrument):
        """Load the Operator patch mapped to ``instrument`` onto track ``x``.

        Steps:
        1. Look up the patch filename in ``operator_map``.
        2. Load the pickled ``(names, values)`` pair from
           ``./params/Operator/<filename>.pkl``.
        3. Write the values to device 0 on the track.
        4. Rename the track to the patch filename.

        Failures are printed with a traceback, not raised.

        Args:
            x (int): Track index.
            y (int): Clip slot index (unused).
            instrument (int): Key into ``operator_map``, whose keys are ints.
        """
        try:
            print("setting instrument to", instrument)
            filename = operator_map[instrument]
            print("OP Patch:", filename)
            # SECURITY: pickle can execute arbitrary code on load. This is
            # presumably acceptable only because these patch files ship with
            # the project; never point this at user-supplied files.
            with open(f"./params/Operator/{filename}.pkl", "rb") as patch_file:
                _names, values = pickle.load(patch_file)
            client.send_message("/live/device/set/parameters/value", (x, 0, *values))
            client.send_message("/live/track/set/name", (x, filename))
        except Exception as e:
            print()
            print("failed to set instrument", e)
            traceback.print_exc()

    @staticmethod
    def is_midi_track(x):
        """Report whether track ``x`` has MIDI input.

        Args:
            x (int): Track index.

        Returns:
            The second element of the ``has_midi_input`` reply, or ``None`` if
            the query failed.
        """
        try:
            return client.query("/live/track/get/has_midi_input", (x,))[1]
        except Exception:
            return None
# Pending generations: (x, y) clip position -> prompt text.
WAIT_LIST = dict()

# Spinner frames cycled through as the clip name while a generation runs.
_SPINNER_FRAMES = ("-", "/", "|", "\\")
# Each (x, y) position lazily gets its own endless spinner iterator.
SPINNER_GRID = defaultdict(lambda: itertools.cycle(_SPINNER_FRAMES))
class Gen:
    """Bookkeeping for background MIDI generations, keyed by clip position.

    Pending jobs live in the module-level ``WAIT_LIST`` dict
    (``(x, y) -> prompt``). Each job writes its output next to this file under
    ``gens/<x>-<y>``; the generator is expected to produce ``<base>.midi``.
    All methods are ``@staticmethod`` and are called as ``Gen.method(...)``.
    """

    @staticmethod
    def add_prompt(x, y, prompt):
        """Register a pending generation for clip ``(x, y)``.

        Any ``.midi`` left over from an earlier run is deleted first, so the
        appearance of a new one signals that this job has finished.

        Args:
            x (int): Track index.
            y (int): Clip slot index.
            prompt (str): Prompt being generated.
        """
        try:
            os.remove(Gen.get_filename(x, y) + ".midi")
        except FileNotFoundError:
            # Nothing stale to clear.
            pass
        WAIT_LIST[(x, y)] = prompt

    @staticmethod
    def wait_list():
        """Return ``[(x, y, base_filename), ...]`` for every pending job.

        A new list is built so callers may call ``Gen.finished``, which mutates
        ``WAIT_LIST``, while iterating over the result.

        Returns:
            list: One ``(x, y, base_filename)`` tuple per pending job.
        """
        return [(x, y, Gen.get_filename(x, y)) for (x, y) in WAIT_LIST]

    @staticmethod
    def finished(x, y):
        """Remove the pending job for clip ``(x, y)`` and return its prompt.

        Args:
            x (int): Track index.
            y (int): Clip slot index.

        Returns:
            str: The prompt that was registered for ``(x, y)``.

        Raises:
            KeyError: If no job is pending for ``(x, y)``.
        """
        return WAIT_LIST.pop((x, y))

    @staticmethod
    def get_filename(x, y):
        """Return the extension-less output path for clip ``(x, y)``.

        Args:
            x (int): Track index.
            y (int): Clip slot index.

        Returns:
            str: ``<directory of this file>/gens/<x>-<y>``. Callers append
            ``.midi`` or ``.abc``.
        """
        gen_dir = os.path.join(os.path.dirname(__file__), "gens")
        return os.path.join(gen_dir, f"{x}-{y}")
# Asynchronous MIDI generation from a text prompt (runs make_midi off the main loop).
def start_thread_prompt(x, y, prompt):
    """Kick off background MIDI generation for clip ``(x, y)``.

    Runs ``make_midi(prompt, <output base path>)`` on a new thread and returns
    immediately. The polling loop detects completion by checking for the
    resulting ``.midi`` file.

    Args:
        x (int): Track index.
        y (int): Clip slot index.
        prompt (str): Text prompt to generate from.
    """
    output_base = Gen.get_filename(x, y)
    worker = threading.Thread(target=make_midi, args=(prompt, output_base))
    worker.start()
def start_thread_modify(x, y, prompt, existing_abc):
    """Kick off a background rewrite of an earlier result for clip ``(x, y)``.

    Runs ``modify_midi(prompt, existing_abc, <output base path>)`` on a new
    thread and returns immediately. The polling loop detects completion by
    checking for the resulting ``.midi`` file.

    Args:
        x (int): Track index.
        y (int): Clip slot index.
        prompt (str): Instruction describing the modification.
        existing_abc (str): ABC notation of the previous result.
    """
    output_base = Gen.get_filename(x, y)
    worker = threading.Thread(
        target=modify_midi,
        args=(prompt, existing_abc, output_base),
    )
    worker.start()
def extract_abc_title(abc):
    """Return the text of the first ``T:`` (title) field in ABC notation.

    Args:
        abc (str): ABC notation source.

    Returns:
        str | None: Everything after ``T:`` on the first line that starts with
        it, excluding a trailing carriage return. Returns ``None`` when no line
        starts with ``T:``; the previous version raised ``AttributeError`` here.
    """
    # MULTILINE makes ^/$ anchor at every line. The optional \r keeps Windows
    # (CRLF) line endings out of the captured title.
    match = re.search(r"^T:(.*?)\r?$", abc, re.MULTILINE)
    if match is None:
        return None
    return match.group(1)
def _insert_finished_generations():
    """Insert every finished generation and advance the spinner on the rest."""
    for x, y, filename in Gen.wait_list():
        midi_path = filename + ".midi"
        if not os.path.exists(midi_path):
            # Still generating: show the next spinner frame as the clip name.
            Clip.set_name(x, y, next(SPINNER_GRID[(x, y)]))
            continue
        prompt = Gen.finished(x, y)
        print("finished generating", filename)
        print("inserting clip name =", prompt)
        try:
            Clip.insert_clip(x, y, midi_path, prompt)
        except Exception:
            # One bad result must not kill the polling loop for every other clip.
            print("failed to insert clip", filename)
            traceback.print_exc()


def _start_generation_for_selected_clip():
    """Start a generation if the selected clip's name is a new prompt."""
    x, y, name = Clip.get_curr_name()
    if name is None:
        # Nothing selected (x is None too); skip the track query entirely.
        return
    if Clip.is_midi_track(x) is False:
        print("not a midi track")
        return
    # Names starting with "AI" were written by insert_clip; names of length
    # <= 1 are spinner frames or empty names. Neither is a new prompt.
    if name.startswith("AI") or len(name) <= 1:
        return

    print(f'AI needed clip found: "{name}"')
    prompt = name
    base = Gen.get_filename(x, y)
    abc_path = base + ".abc"

    # Modify the earlier result only when its ABC source is on disk AND the
    # clip still has notes; otherwise generate from scratch. get_notes returns
    # None on failure, which counts as "no notes".
    existing_abc = None
    if os.path.exists(abc_path) and Clip.get_notes(x, y):
        print("already generated", base)
        try:
            with open(abc_path, "r") as f:
                existing_abc = f.read()
        except OSError as e:
            print("could not read previous ABC, generating from scratch:", e)

    Gen.add_prompt(x, y, prompt)
    if existing_abc is None:
        start_thread_prompt(x, y, prompt)
    else:
        start_thread_modify(x, y, prompt, existing_abc)


def event_loop():
    """Run one polling tick.

    First inserts any generations whose ``.midi`` output has appeared, and
    animates the spinner on those still running. Then checks whether the
    currently selected clip has been named with a new prompt and, if so,
    starts a background generation (or modification) for it.
    """
    _insert_finished_generations()
    _start_generation_for_selected_clip()
# Entry point: poll Ableton forever, sleeping 0.1 s between ticks.
_POLL_INTERVAL_SECONDS = 0.1
print("waiting for named MIDI clip to appear..")
while True:
    event_loop()  # handle finished jobs, then inspect the selected clip
    time.sleep(_POLL_INTERVAL_SECONDS)