-
Notifications
You must be signed in to change notification settings - Fork 0
/
storage.py
379 lines (317 loc) · 11.8 KB
/
storage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
from typing import List, Optional, Any, Dict
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from aiocache import cached, Cache
import aiosqlite, os, logging, jsonpickle
# Logger for this module; records are emitted under the name "storage".
log = logging.getLogger(name="storage")
# Alias marking integers that hold a users.id primary key.
UserId = int
@dataclass
class Criteria:
    """Optional filters applied when listing a user's symbols."""

    # Restrict the listing to this single symbol; None means no restriction.
    symbol: Optional[str] = field(default=None)
    # Page selector used by SymbolStorage.get_all_symbols (10 rows per page).
    page: Optional[int] = field(default=None)
@dataclass
class UserKey:
    """Identifies a user together with timestamps describing their data's freshness.

    ``symbols`` maps a symbol to a timestamp (SymbolStorage fills it with the
    latest note time per symbol); ``symbol_key`` condenses that map to one
    string so the whole key can be rendered compactly.
    """

    uid: UserId
    modified: datetime
    symbols: Dict[str, datetime] = field(default_factory=dict)

    @property
    def symbol_key(self) -> str:
        """Return the newest timestamp in ``symbols`` as text, or ``""`` if empty."""
        if not self.symbols:
            return ""
        newest = max(self.symbols.values())
        return str(newest)

    def __str__(self):
        return f"UserKey<{self.uid}, {self.modified}, {self.symbol_key}>"

    def __repr__(self):
        # Reuse the compact __str__ form instead of the dataclass default repr.
        return str(self)
@dataclass
class SymbolRow:
    """One row of the ``symbols`` table as returned to callers."""

    symbol: str
    modified: datetime
    # Boolean flags mirrored from the symbols table columns of the same name;
    # their exact meaning is defined by the code that consumes them.
    earnings: bool
    candles: bool
    options: bool
    # Serialized payload (jsonpickle output from set_symbol, "{}" when new).
    data: str
@dataclass
class NoteRow:
    """A note attached to a symbol, read from the ``notes`` table."""

    symbol: str
    # Value of notes.modified for this note.
    ts: datetime
    body: str
@dataclass
class UserSymbol:
    """A symbol row paired with an optional note row for that symbol."""

    symbol: SymbolRow
    # None when no note row was found for the symbol.
    notes: Optional[NoteRow]
@dataclass
class SymbolStorage:
    """Async SQLite persistence for users, watched symbols, lots and notes.

    ``open()`` must be awaited before any other method; every method asserts
    that ``db`` has been set.
    """

    db: Optional[aiosqlite.Connection] = None

    async def open(self, name: str = "feevee.db", path: Optional[str] = None):
        """Connect to ``<path>/<name>`` and create any missing tables/indexes.

        :param name: database file name.
        :param path: directory containing the file; when falsy the
            ``MONEY_CACHE`` environment variable is used (``KeyError`` if unset).
        """
        log.info("db:opening")
        path = os.path.join(path if path else os.environ["MONEY_CACHE"], name)
        self.db = await aiosqlite.connect(path)
        schema = (
            "CREATE TABLE IF NOT EXISTS users (id INTEGER PRIMARY KEY AUTOINCREMENT, created DATETIME NOT NULL, modified DATETIME NOT NULL, email TEXT NOT NULL, password TEXT NOT NULL)",
            "CREATE UNIQUE INDEX IF NOT EXISTS users_email_idx ON users (email)",
            "CREATE TABLE IF NOT EXISTS user_symbol (user_id INTEGER NOT NULL REFERENCES users(id), symbol TEXT NOT NULL)",
            "CREATE UNIQUE INDEX IF NOT EXISTS user_symbol_idx ON user_symbol (user_id, symbol)",
            "CREATE TABLE IF NOT EXISTS user_lot (id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL REFERENCES users(id), symbol TEXT NOT NULL, purchased DATETIME NOT NULL, quantity TEXT NOT NULL, price TEXT NOT NULL)",
            "CREATE INDEX IF NOT EXISTS user_lot_idx ON user_lot (user_id, symbol)",
            "CREATE TABLE IF NOT EXISTS user_lot_ledger (id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL REFERENCES users(id), modified DATETIME NOT NULL, body TEXT NOT NULL)",
            "CREATE UNIQUE INDEX IF NOT EXISTS user_lot_ledger_idx ON user_lot_ledger (user_id)",
            "CREATE TABLE IF NOT EXISTS symbols (symbol TEXT NOT NULL, modified DATETIME NOT NULL, safe BOOL NOT NULL, earnings BOOL NOT NULL, candles BOOL NOT NULL, options BOOL NOT NULL, data TEXT NOT NULL)",
            "CREATE UNIQUE INDEX IF NOT EXISTS symbols_symbol_idx ON symbols (symbol)",
            "CREATE TABLE IF NOT EXISTS notes (id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL REFERENCES users(id), symbol TEXT NOT NULL, modified DATETIME NOT NULL, noted_price DECIMAL(10, 5) NOT NULL, future_price DECIMAL(10, 5), body TEXT NOT NULL)",
            "CREATE UNIQUE INDEX IF NOT EXISTS notes_idx ON notes (user_id, symbol)",
        )
        for statement in schema:
            await self.db.execute(statement)
        await self.db.commit()

    async def close(self):
        """Close the underlying connection."""
        assert self.db
        await self.db.close()

    async def get_user_key_by_user_id(self, user_id: int) -> UserKey:
        """Build the ``UserKey`` for ``user_id``.

        The key holds the user's ``modified`` stamp plus the latest note
        timestamp for each of the user's noted symbols.

        :raises AssertionError: if no user with ``user_id`` exists.
        """
        assert self.db
        dbc = await self.db.execute(
            "SELECT id, modified FROM users WHERE id = ?",
            [user_id],
        )
        row = await dbc.fetchone()
        assert row is not None
        symbol_keys = await self._get_user_symbol_keys(user_id)
        return UserKey(row[0], self._parse_datetime(row[1]), symbol_keys)

    async def get_all_user_ids(self) -> List[UserId]:
        """Return the ids of every user."""
        assert self.db
        dbc = await self.db.execute("SELECT id FROM users")
        return [row[0] for row in await dbc.fetchall()]

    async def get_all_notes(self, user_key: UserKey) -> Dict[str, NoteRow]:
        """Return the user's notes keyed by symbol, newest note per symbol."""
        assert self.db
        notes: Dict[str, NoteRow] = {}
        dbc = await self.db.execute(
            "SELECT symbol, modified, noted_price, future_price, body FROM notes WHERE user_id = ? ORDER BY modified DESC",
            [user_key.uid],
        )
        for row in await dbc.fetchall():
            symbol = row[0]
            # Rows are newest-first, so the first row seen per symbol wins.
            if symbol in notes:
                continue
            notes[symbol] = NoteRow(symbol, self._parse_datetime(row[1]), row[4])
        return notes

    async def get_all_symbols(
        self, user_key: UserKey, criteria: Criteria
    ) -> Dict[str, UserSymbol]:
        """Return the user's "safe" symbols, each with that user's note if any.

        :param criteria: ``symbol`` restricts the result to one symbol;
            a truthy ``page`` selects a 10-row slice of the ordered rows.
        The ``FEEVEE_SYMBOLS`` environment variable, when set, further
        restricts the result to its space-separated symbols.
        """
        assert self.db
        dbc = await self.db.execute(
            """
            SELECT s.symbol, s.modified, s.earnings, s.candles, s.options, s.data, notes.modified, notes.body AS notes
            FROM symbols AS s
            LEFT JOIN notes ON (s.symbol = notes.symbol AND notes.user_id = ?)
            WHERE s.safe AND s.symbol IN (SELECT symbol FROM user_symbol WHERE user_id = ?)
            AND (? is NULL OR s.symbol = ?)
            ORDER BY s.symbol
            """,
            # Placeholders bind left to right: join user, filter user, symbol x2.
            [user_key.uid, user_key.uid, criteria.symbol, criteria.symbol],
        )
        rows = list(await dbc.fetchall())
        # NOTE(review): page 0 and None both disable paging, and slicing happens
        # before the FEEVEE_SYMBOLS filter below — confirm page is 1-based.
        if criteria.page:
            rows = rows[criteria.page * 10 : criteria.page * 10 + 10]

        filtering: Optional[Set[str]] = None
        if "FEEVEE_SYMBOLS" in os.environ:
            filtering = set(os.environ["FEEVEE_SYMBOLS"].split(" "))

        symbols: Dict[str, UserSymbol] = {}
        for row in rows:
            symbol = row[0]
            if filtering is not None and symbol not in filtering:
                continue
            note = (
                NoteRow(symbol, self._parse_datetime(row[6]), row[7])
                if row[6]
                else None
            )
            symbols[symbol] = UserSymbol(
                SymbolRow(
                    symbol,
                    self._parse_datetime(row[1]),
                    row[2],
                    row[3],
                    row[4],
                    row[5],
                ),
                note,
            )
        return symbols

    async def add_symbols(self, user_key: UserKey, symbols: List[str]) -> List[str]:
        """Add ``symbols`` to the user's watch list.

        Unknown symbols are first inserted into ``symbols`` with empty data.
        Returns the symbols that were newly attached to the user.
        """
        assert self.db
        changed: List[str] = []
        for symbol in symbols:
            # TODO Make data nullable in the future.
            # Column list matches set_symbol; every NOT NULL column is supplied.
            await self.db.execute(
                """
                INSERT INTO symbols (symbol, modified, candles, options, earnings, safe, data)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT DO NOTHING
                """,
                [symbol, datetime.utcnow(), True, False, False, False, "{}"],
            )
            # Only the user_symbol insert below counts as a change for this user.
            changes_before = self.db.total_changes
            await self.db.execute(
                "INSERT INTO user_symbol (user_id, symbol) VALUES (?, ?) ON CONFLICT DO NOTHING",
                [user_key.uid, symbol],
            )
            if changes_before != self.db.total_changes:
                log.info(f"{symbol:6} added")
                changed.append(symbol)
        if changed:
            await self._user_modified(user_key)
        await self.db.commit()
        return changed

    async def remove_symbols(self, user_key: UserKey, symbols: List[str]) -> List[str]:
        """Remove ``symbols`` from the user's watch list.

        Returns the symbols that were actually removed (mirrors ``add_symbols``).
        """
        assert self.db
        removed: List[str] = []
        for symbol in symbols:
            changes_before = self.db.total_changes
            await self.db.execute(
                "DELETE FROM user_symbol WHERE user_id = ? AND symbol = ?",
                [user_key.uid, symbol],
            )
            if changes_before != self.db.total_changes:
                log.info(f"{symbol:6} removed")
                removed.append(symbol)
        if removed:
            await self._user_modified(user_key)
        await self.db.commit()
        return removed

    async def set_symbol(self, symbol: str, safe: bool, data: Dict[str, Any]):
        """Insert or update a symbol's serialized data and ``safe`` flag.

        On conflict only ``modified``, ``safe`` and ``data`` are updated; the
        boolean flag columns keep their existing values.
        """
        assert self.db
        serialized = jsonpickle.encode(data)
        log.info(f"{serialized}")
        await self.db.execute(
            """
            INSERT INTO symbols (symbol, modified, candles, options, earnings, safe, data) VALUES (?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(symbol) DO UPDATE
            SET modified = excluded.modified, safe = excluded.safe, data = excluded.data
            """,
            # NOTE(review): local time here, while add_symbols writes utcnow()
            # into the same column — confirm which clock readers expect.
            [symbol, datetime.now(), True, False, False, safe, serialized],
        )
        await self.db.commit()

    async def update_lots(self, user_key: UserKey, lots: str) -> UserKey:
        """Replace the user's lot ledger text and return the refreshed key."""
        assert self.db
        await self.db.execute(
            """
            INSERT INTO user_lot_ledger (user_id, modified, body) VALUES (?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE
            SET modified = excluded.modified, body = excluded.body
            """,
            [user_key.uid, datetime.utcnow(), lots],
        )
        await self._user_modified(user_key)
        updated_user_key = await self.get_user_key_by_user_id(user_key.uid)
        await self.db.commit()
        return updated_user_key

    async def get_lots(self, user_key: UserKey) -> str:
        """Return the user's lot ledger text, or ``""`` if none is stored."""
        assert self.db
        dbc = await self.db.execute(
            "SELECT user_id, modified, body FROM user_lot_ledger WHERE user_id = ?",
            [user_key.uid],
        )
        # user_lot_ledger_idx is unique on user_id, so at most one row exists.
        row = await dbc.fetchone()
        return row[2] if row else ""

    async def add_notes(
        self,
        user_key: UserKey,
        symbol: str,
        modified: datetime,
        noted_price: Decimal,
        future_price: Optional[Decimal],
        body: str,
    ) -> UserKey:
        """Insert or replace the user's note for ``symbol``; return the refreshed key.

        Prices are stored as their ``str`` form; ``future_price`` is stored as
        NULL only when it is ``None`` (a zero price is kept).
        """
        assert self.db
        await self.db.execute(
            """
            INSERT INTO notes (user_id, symbol, modified, noted_price, future_price, body)
            VALUES (?, ?, ?, ?, ?, ?)
            ON CONFLICT(user_id, symbol) DO UPDATE
            SET modified = excluded.modified, noted_price = excluded.noted_price, future_price = excluded.future_price, body = excluded.body
            """,
            [
                user_key.uid,
                symbol,
                modified,
                str(noted_price),
                str(future_price) if future_price is not None else None,
                body,
            ],
        )
        updated_user_key = await self.get_user_key_by_user_id(user_key.uid)
        await self.db.commit()
        return updated_user_key

    async def _get_user_symbol_keys(self, user_id: int) -> Dict[str, datetime]:
        """Map each symbol the user has notes on to its latest note timestamp."""
        assert self.db
        dbc = await self.db.execute(
            """
            SELECT symbol, MAX(notes.modified) AS key
            FROM notes
            WHERE user_id = ?
            GROUP BY symbol
            """,
            [user_id],
        )
        return {row[0]: self._parse_datetime(row[1]) for row in await dbc.fetchall()}

    async def _user_modified(self, user_key: UserKey):
        """Stamp ``users.modified`` with the current UTC time (no commit)."""
        assert self.db
        await self.db.execute(
            "UPDATE users SET modified = ? WHERE id = ?",
            [datetime.utcnow(), user_key.uid],
        )

    def _parse_datetime(self, value: str) -> datetime:
        """Parse a ``YYYY-MM-DD HH:MM:SS[.ffffff]`` timestamp string.

        Python's sqlite3 datetime adapter stores ``isoformat(" ")``, which
        omits the fractional part when microseconds are zero, so both forms
        must be accepted.
        """
        try:
            return datetime.strptime(value, "%Y-%m-%d %H:%M:%S.%f")
        except ValueError:
            return datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
# Process-wide storage instance, created lazily by get_db().
db: Optional[SymbolStorage] = None


async def get_db() -> SymbolStorage:
    """Return the shared SymbolStorage, constructing it on the first call.

    This function does not call ``open()`` on the instance.
    """
    global db
    if db is None:
        db = SymbolStorage()
        log.info("db:created")
    return db