Skip to content

Commit

Permalink
Merge pull request #23 from Aquila-Network/to_async
Browse files Browse the repository at this point in the history
To async, disk persist, bug fixes
  • Loading branch information
freakeinstein authored Jan 17, 2022
2 parents 7d9639e + 8edfc70 commit 514ad9d
Show file tree
Hide file tree
Showing 11 changed files with 569 additions and 309 deletions.
2 changes: 1 addition & 1 deletion src/config.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
auth:
pubkey: "/ossl/public.pem"
ipfs:
api: "http://127.0.0.1:5001/api"
gateway: "http://127.0.0.1:8080"
206 changes: 206 additions & 0 deletions src/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import logging

import fasttext
from utils import downloader
import hashlib
import base58
import json

from sentence_transformers import SentenceTransformer

import os

# define constants
MODEL_FASTTEXT = "ftxt"      # schema URL prefix selecting a fasttext model
MODEL_S_TRANSFORMER = "strn" # schema URL prefix selecting a sentence-transformer model
PREDICT_BATCH_SIZE = 1000    # max number of texts folded into one prediction batch

# Maintain a model directory
data_dir = os.environ["DATA_STORE_LOCATION"]  # NOTE(review): assumes the env value ends with "/" — confirm
model_dir = data_dir + "models/"  # on-disk cache for downloaded model files
model_dict = {}  # shared cache: encoder-url-hash -> {"type": ..., "model": ...}

def get_url (schema):
    """
    Get model url from a schema.

    Returns the value of the "encoder" key, or None when it is absent.
    """

    # dict.get already yields None for a missing key; this replaces the
    # original's non-idiomatic `!= None` comparison and double lookup
    return schema.get("encoder")

def get_url_hash (url):
    """Return the base58-encoded SHA-256 digest of *url* (used as a model cache key)."""
    digest = hashlib.sha256(url.encode("utf-8")).digest()
    return base58.b58encode(digest).decode("utf-8")

def download_model (url, directory, file_name):
    """
    Download a model from a URL.

    *url* is a prefixed locator: "ftxt:<http(s)/ipfs url>" for fasttext
    models (fetched into <directory>/<file_name>.bin) or "strn:<name>" for
    sentence-transformer models (resolved by name, nothing fetched here).

    Returns (model_type, location); (None, "") for any unsupported locator.
    """

    scheme, _, remainder = url.partition(":")

    # handle fasttext models from url or IPFS
    if scheme == MODEL_FASTTEXT:
        transport = remainder.split(":")[0]
        if transport == "http" or transport == "https":
            return MODEL_FASTTEXT, downloader.http_download(remainder, directory, file_name + ".bin")
        if transport == "ipfs":
            return MODEL_FASTTEXT, downloader.ipfs_download(remainder, directory, file_name + ".bin")
        # fall through: an unknown transport inside "ftxt:" previously hit an
        # implicit `return None`, so the caller's 2-tuple unpack raised
        # TypeError; report it as the documented error tuple instead
    elif scheme == MODEL_S_TRANSFORMER:
        # sentence-transformer models are referenced by name only
        return MODEL_S_TRANSFORMER, remainder

    logging.error("Invalid encoder specified in schema definition.")
    return None, ""

def memload_model (model_type, model_filename):
    """
    Load a model from disk.

    Returns (model_type, model_object) on success, (None, None) when the
    type is unrecognized or the filename is empty.
    """

    # no file/name to load from — nothing we can do for any model type
    if not model_filename:
        return None, None

    if model_type == MODEL_FASTTEXT:
        logging.debug("loading fasttext model into memory..")
        return model_type, fasttext.load_model(model_filename)

    if model_type == MODEL_S_TRANSFORMER:
        logging.debug("loading STransformer model into memory..")
        return model_type, SentenceTransformer(model_filename)

    # unknown model type
    return None, None

class EncodeRequest ():
    """One compression request: a response-slot id plus the texts to encode."""

    def __init__(self, id_in, text_in):
        # slot index into the owning Encoder's response queue
        self.id = id_in
        # texts to be turned into vectors
        self.text = text_in

class Encoder ():
    """
    Queue-based encoder bound to one database's embedding model.

    Requests are enqueued with an id; process_queue() batches pending
    requests, runs the model, and writes each request's vectors into
    response_queue at the slot matching its id.
    """

    def __init__(self, encoder_name_in, request_queue_in):
        # hashed encoder URL; also the key into the shared model_dict cache
        self.encoder_name = get_url_hash(encoder_name_in)
        # to handle requests
        self.request_queue = request_queue_in
        self.request_id_counter = 0
        self.request_id_counter_max = 10000
        # to handle responses; slot index == request id
        self.response_queue = [None] * self.request_id_counter_max

    def __del__(self):
        logging.debug("killed encoder for database")

    def count_request_id (self):
        """Return the current request id and advance the counter (wraps at max)."""
        ret_ = self.request_id_counter
        self.request_id_counter = (self.request_id_counter + 1) % self.request_id_counter_max
        return ret_

    def preload_model (self, json_schema, database_name):
        """
        Download a model and load it into memory.

        Returns True when the model is (already) usable, False otherwise.
        """

        # prefill model & hash dictionary
        global model_dict

        try:
            # already loaded: nothing to do
            if model_dict.get(self.encoder_name):
                return True

            model_type_, model_file_loc_ = download_model(get_url(json_schema), model_dir, self.encoder_name)

            # download_model signals failure with a None type and an EMPTY
            # location string; the original tested `model_file_loc_ != None`,
            # which "" passes, so failed downloads slipped into memload_model
            if model_type_ is None:
                return False

            model_dict[self.encoder_name] = {}
            # load into memory
            model_dict[self.encoder_name]["type"], model_dict[self.encoder_name]["model"] = memload_model(model_type_, model_file_loc_)

            # memory loading failed
            if model_dict[self.encoder_name]["type"] is None:
                logging.error("Memory loading of model failed")
                # drop the broken cache entry; the original left it behind, so
                # every later preload call wrongly reported the model as loaded
                del model_dict[self.encoder_name]
                return False

            if model_dict[self.encoder_name].get("model"):
                logging.debug("Model loaded for database: " + database_name)
                return True

            logging.error("Model loading failed for database: " + database_name)
            # reset DB - hash map
            del model_dict[self.encoder_name]
            return False

        except Exception as e:
            logging.error(e)
            return False

    async def enqueue_compress_data (self, texts):
        """
        Add to request queue for compression.

        Returns the request id whose response_queue slot will eventually
        hold the resulting vectors.
        """
        request_ = EncodeRequest(self.count_request_id(), texts)

        await self.request_queue.put(request_)

        return request_.id

    async def process_queue (self):
        """
        Load an already existing model, pop the request queue, compress the
        batched texts, and write each request's slice into response_queue.
        """

        request_data = []
        request_metadata = []
        max_batch_len = PREDICT_BATCH_SIZE  # model's batching capacity

        # create batch from req. queue
        while not self.request_queue.empty():
            # get an item from queue
            section_ = await self.request_queue.get()
            request_data += section_.text
            request_metadata.append((section_.id, len(section_.text)))
            # check max. batch length achieved
            if len(request_data) > max_batch_len:
                break

        # empty queue: nothing to predict (the original indexed
        # request_metadata[0] unconditionally and raised IndexError)
        if not request_metadata:
            return []

        # prefill model & hash dictionary
        global model_dict

        if not model_dict.get(self.encoder_name):
            # try dynamic loading of a previously downloaded fasttext model
            try:
                model_dict[self.encoder_name] = {}
                model_dict[self.encoder_name]["type"], model_dict[self.encoder_name]["model"] = memload_model(MODEL_FASTTEXT, model_dir + self.encoder_name + ".bin")
            except Exception as e:
                logging.error("Model not pre-loaded for database.")
                logging.error(e)
                return []

        result = []
        try:
            # fasttext model prediction
            if model_dict[self.encoder_name]["type"] == MODEL_FASTTEXT:
                # fasttext doesn't take in batch; so, loop it.
                result = [model_dict[self.encoder_name]["model"].get_sentence_vector(line_).tolist() for line_ in request_data]
            # stransformer model prediction
            if model_dict[self.encoder_name]["type"] == MODEL_S_TRANSFORMER:
                result = model_dict[self.encoder_name]["model"].encode(request_data).tolist()

        except Exception as e:
            logging.error(e)
            logging.error("Model prediction error for database.")

        # add results to response queue: each request owns a contiguous slice.
        # The original advanced the slice start by only the PREVIOUS request's
        # length instead of the running total, mis-slicing every batch of
        # three or more requests; keep a cumulative offset instead.
        offset = 0
        for req_id, req_len in request_metadata:
            self.response_queue[req_id] = result[offset:offset + req_len]
            offset += req_len
87 changes: 53 additions & 34 deletions src/index.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,51 @@
import logging

from flask import Flask, request
from flask_cors import CORS
from flask import jsonify
from quart import Quart
from quart import request

from functools import wraps
import asyncio

from utils import config
import authentication
import router

import time
from multiprocessing import Process
import manager as man_

app = Flask(__name__, instance_relative_config=True)
app = Quart(__name__, instance_relative_config=True)

# Server starter
def flaskserver ():
def quartserver ():
"""
start server
"""
app.run(host='0.0.0.0', port=5002, debug=False)

server = Process(target=flaskserver)

# Enable CORS
CORS(app)

# Add authentication
def authenticate ():
def decorator (f):
@wraps(f)
def wrapper (*args, **kwargs):
params = extract_request_params(request)
async def wrapper (*args, **kwargs):
params = await extract_request_params(request)

if not params or not "data" in params or not "signature" in params:
return "Unauthorised access", 401

if not authentication.check(params["data"], params["signature"]):
return "Unauthorised access", 401

return f(*args, **kwargs)
return await f(*args, **kwargs)

return wrapper
return decorator

def extract_request_params (request):
async def extract_request_params (request):
if not request.is_json:
logging.error("Cannot parse request parameters")

# request is invalid
return {}

# Extract JSON data
data_ = request.get_json()
data_ = await request.get_json()

return data_

Expand All @@ -70,15 +63,13 @@ def info ():

@app.route("/prepare", methods=['POST'])
@authenticate()
def prepare_model ():
async def prepare_model ():
"""
Preload and prepare model from schema definition
"""

# get parameters
params = None
if extract_request_params(request).get("data"):
params = extract_request_params(request)["data"]
params = (await extract_request_params(request)).get("data")

if not params:
# Build error response
Expand All @@ -88,7 +79,7 @@ def prepare_model ():
}, 400

if "schema" in params:
database_name = router.preload_model(params.get("schema"))
database_name = app.manager.preload_model(params.get("schema"))

# Build response
if database_name:
Expand All @@ -108,15 +99,13 @@ def prepare_model ():
}, 400

@app.route("/compress", methods=['POST'])
def compress_data ():
async def compress_data ():
"""
generate embeddings for an input data
"""

# get parameters
params = None
if extract_request_params(request).get("data"):
params = extract_request_params(request)["data"]
params = (await extract_request_params(request)).get("data")

if not params:
# Build error response
Expand All @@ -126,18 +115,48 @@ def compress_data ():
}, 400

if "text" in params and "databaseName" in params:
vectors = router.compress_data(params.get("databaseName"), params.get("text"))
vectors = await app.manager.compress_data(params.get("databaseName"), params.get("text"))

# Build response
return {
"success": True,
"vectors": vectors
}, 200
if vectors:
return {
"success": True,
"vectors": vectors
}, 200
else:
return {
"success": False,
"message": "Database not found"
}, 400
else:
return {
"success": False,
"message": "Invalid parameters"
}, 400

@app.before_serving
async def init_variables():
    # one Manager instance, shared by all request handlers via the app object
    app.manager = man_.Manager()
    # prepare HUB from backup — best-effort: a failed restore is logged and
    # the server still starts
    try:
        app.manager.prepare_hub()
    except Exception as e:
        logging.error("Backup restore failed")
        logging.error(e)
    # initialize background task controller — presumably polled by the
    # manager's background loop; confirm semantics in manager.py (not visible here)
    app.manager.bg_task_active = True

@app.before_serving
async def init_tasks():
    # initialize background task
    # NOTE(review): this REBINDS app.manager.background_task from the bound
    # method to the asyncio task wrapping its own invocation; the
    # after_serving hook relies on the attribute holding the task so it can
    # cancel it. Consider a distinct attribute name to avoid the shadowing.
    app.manager.background_task = asyncio.ensure_future(app.manager.background_task())
    # app.add_background_task(background_task)

@app.after_serving
async def shutdown():
    # shutdown background task: clear the run flag first, then cancel the
    # task (app.manager.background_task holds the asyncio task created at startup)
    app.manager.bg_task_active = False
    app.manager.background_task.cancel()

if __name__ == "__main__":
server.start()
quartserver()
Loading

0 comments on commit 514ad9d

Please sign in to comment.