Fix model re-loading on refresh #4
pramitchoudhary committed Oct 19, 2023
1 parent 8b50af6 commit ab44aad
Showing 5 changed files with 198 additions and 135 deletions.
4 changes: 2 additions & 2 deletions app.toml
@@ -4,7 +4,7 @@ title = "SQL-Sidekick"
description = "QnA with tabular data using NLQ"
LongDescription = "about.md"
Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP"]
Version = "0.0.14"
Version = "0.0.16"

[Runtime]
MemoryLimit = "64Gi"
@@ -13,7 +13,7 @@ module = "start"
VolumeMount = "/meta_data"
VolumeSize = "100Gi"
ResourceVolumeSize = "64Gi"
-GPUCount = 1
+GPUCount = 2
RuntimeVersion = "ub2004_cuda114_cudnn8_py38_wlatest_a10g"
RoutingMode = "BASE_URL"
EnableOIDC = true
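Doubling GPUCount from 1 to 2 is what lets the reload fix in sidekick/query.py below keep one model per device instead of evicting on every swap. A minimal sketch of the runtime check the new code relies on (the decision rule here is illustrative, not the app's exact policy):

import torch

# With GPUCount = 2 in app.toml, the runtime exposes two CUDA devices;
# query.py reads this via torch.cuda.device_count() to decide whether it
# can keep several models resident at once.
n_gpus = torch.cuda.device_count()
keep_multiple_models = n_gpus > 1
print(f"GPUs visible: {n_gpus}; keep multiple models: {keep_multiple_models}")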
93 changes: 60 additions & 33 deletions sidekick/query.py
@@ -12,16 +12,26 @@
import torch
import torch.nn.functional as F
from langchain import OpenAI
-from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
-                         LLMPredictor, ServiceContext, SQLDatabase)
+from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
from llama_index.indices.struct_store import SQLContextContainerBuilder
-from sidekick.configs.prompt_template import (DEBUGGING_PROMPT,
-                                              NSQL_QUERY_PROMPT, QUERY_PROMPT,
-                                              STARCODER2_PROMPT, TASK_PROMPT)
+from sidekick.configs.prompt_template import (
+    DEBUGGING_PROMPT,
+    NSQL_QUERY_PROMPT,
+    QUERY_PROMPT,
+    STARCODER2_PROMPT,
+    TASK_PROMPT,
+)
from sidekick.logger import logger
-from sidekick.utils import (_check_file_info, filter_samples, is_resource_low,
-                            load_causal_lm_model, load_embedding_model,
-                            make_dir, read_sample_pairs, remove_duplicates)
+from sidekick.utils import (
+    _check_file_info,
+    filter_samples,
+    is_resource_low,
+    load_causal_lm_model,
+    load_embedding_model,
+    make_dir,
+    read_sample_pairs,
+    remove_duplicates,
+)
from sqlalchemy import create_engine


@@ -40,27 +50,38 @@ def __new__(
        is_regenerate: bool = False,
        is_regenerate_with_options: bool = False,
    ):
-        offloading = is_resource_low()
-        # Initially load one model at a time if the user swapped the model dynamically.
-        # TODO:
-        # 1. Keep multiple models in memory if possible
-        # 2. Support remote model loading as an option
-        if offloading and is_regenerate_with_options or (cls._instance and cls._instance.model_name != model_name):
-            del cls._instance
-            cls._instance = None
+        # TODO: If openai model then only tokenizer needs to be loaded.
+        offloading = is_resource_low(model_name)
+        n_gpus = torch.cuda.device_count()
+        # Initially load one model at a time if the user swapped the model dynamically when GPU = 1
+        # If GPU > 1, load multiple models in memory separately on each device.
+        # TODO
+        # Support remote model loading as an option
+        if (
+            offloading
+            and is_regenerate_with_options
+            or (n_gpus == 1 and cls._instance and cls._instance.model_name and cls._instance.model_name != model_name)
+        ):
+            if cls._instance.models.get(cls._instance.model_name, None):
+                _name = cls._instance.model_name
+                del cls._instance.models[_name]
+                cls._instance.models[_name] = None

            gc.collect()
            torch.cuda.empty_cache()
-            logger.info(f"Low memory: {offloading}/ Model re-initialization: True")
+            logger.info(f"Low memory: {offloading}/ Model re-initialization: {is_regenerate_with_options}")

-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-            cls._instance.model, cls._instance.tokenizer = load_causal_lm_model(
+        if cls._instance is None or (cls._instance and not cls._instance.models.get(model_name, None)):
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+            cls._instance.models, cls._instance.tokenizers = load_causal_lm_model(
                model_name,
                cache_path=f"{job_path}/models/",
                device=device,
                off_load=offloading,
                re_generate=is_regenerate_with_options,
            )
+            cls._instance.model_name = "h2ogpt-sql-sqlcoder2" if not model_name else model_name
            model_embed_path = f"{job_path}/models/sentence_transformers"
            device = "cuda" if torch.cuda.is_available() else "cpu" if device == "auto" else device
            cls._instance.similarity_model = load_embedding_model(model_path=model_embed_path, device=device)
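The hunk above replaces the single model/tokenizer attributes with dicts keyed by model name, evicting an entry only on single-GPU machines. A self-contained sketch of that caching pattern, with load_model as a simplified stand-in for the project's load_causal_lm_model (class and loader here are illustrative, not the actual implementation):

import gc

import torch

def load_model(name: str):
    # Stand-in for load_causal_lm_model; pretend this returns heavy weights.
    return f"<weights for {name}>"

class Generator:
    _instance = None

    def __new__(cls, model_name: str):
        n_gpus = torch.cuda.device_count()
        inst = cls._instance
        if n_gpus <= 1 and inst is not None and inst.model_name != model_name:
            # Single GPU (or CPU): drop the old model so the new one fits.
            inst.models.pop(inst.model_name, None)
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        if inst is None:
            inst = cls._instance = super().__new__(cls)
            inst.models = {}
        if model_name not in inst.models:
            inst.models[model_name] = load_model(model_name)
        inst.model_name = model_name
        return inst

g = Generator("h2ogpt-sql-sqlcoder2")
g = Generator("h2ogpt-sql-nsql-llama-2-7B")
# On one GPU (or CPU) only the latest model survives; with several GPUs
# both entries stay cached.
print(list(g.models.keys()))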
@@ -269,7 +290,11 @@ def generate_tasks(self, table_names: list, input_question: str):
            raise se

    def generate_sql(
-        self, table_names: list, input_question: str, _dialect: str = "sqlite", model_name: str = "h2ogpt-sql"
+        self,
+        table_names: list,
+        input_question: str,
+        _dialect: str = "sqlite",
+        model_name: str = "h2ogpt-sql-nsql-llama-2-7B",
    ):
        # TODO: Update needed to support multiple tables
        table_name = str(table_names[0].replace(" ", "_")).lower()
@@ -446,8 +471,10 @@ def generate_sql(
        )

        logger.debug(f"Query Text:\n {query}")
-        inputs = self.tokenizer([query], return_tensors="pt")
-        input_length = 1 if self.model.config.is_encoder_decoder else inputs.input_ids.shape[1]
+        tokenizer = self.tokenizers[model_name]
+        inputs = tokenizer([query], return_tensors="pt")
+        model = self.models[model_name]
+        input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
        logger.info(f"Context length: {input_length}")

        # Handle limited context length
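The lookup switch above is the heart of the refresh fix inside generate_sql: the tokenizer and model now come from per-model dicts rather than shared attributes. A hedged sketch of the same pattern with a small public checkpoint standing in for the h2ogpt-sql models (gpt2 is only a stand-in):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # stand-in checkpoint; the app uses h2ogpt-sql models
tokenizers = {model_name: AutoTokenizer.from_pretrained(model_name)}
models = {model_name: AutoModelForCausalLM.from_pretrained(model_name)}

tokenizer = tokenizers[model_name]
model = models[model_name]
inputs = tokenizer(["SELECT * FROM"], return_tensors="pt")
# Decoder-only models echo the prompt in their output, so its token count is
# needed to slice the prompt off later; encoder-decoder outputs start fresh,
# hence input_length = 1 (just the start token).
input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
print(f"Context length: {input_length}")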
@@ -467,20 +494,20 @@
                question_txt=input_question,
            )
            logger.debug(f"Adjusted query Text:\n {query}")
-            inputs = self.tokenizer([query], return_tensors="pt")
-            input_length = 1 if self.model.config.is_encoder_decoder else inputs.input_ids.shape[1]
+            inputs = tokenizer([query], return_tensors="pt")
+            input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
            logger.info(f"Adjusted context length: {input_length}")
        # Generate SQL
        random_seed = random.randint(0, 50)
        torch.manual_seed(random_seed)

        # Greedy search for quick response
-        self.model.eval()
+        model.eval()
        device_type = "cuda" if torch.cuda.is_available() else "cpu"

        if not self.is_regenerate_with_options and not self.is_regenerate:
            # Greedy decoding
-            output = self.model.generate(
+            output = model.generate(
                **inputs.to(device_type),
                max_new_tokens=300,
                temperature=0.5,
@@ -498,7 +525,7 @@
            possible_temp_choice = [0.1, 0.2, 0.3, 0.6, 0.75, 0.9, 1.0]
            random_temperature = np.random.choice(possible_temp_choice, 1)[0]
            logger.debug(f"Selected temperature for fast regeneration : {random_temperature}")
-            output = self.model.generate(
+            output = model.generate(
                **inputs.to(device_type),
                max_new_tokens=300,
                temperature=random_temperature,
@@ -515,7 +542,7 @@
            possible_temp_choice = [0.1, 0.3, 0.5, 0.6, 0.75, 0.9, 1.0]
            random_temperature = np.random.choice(possible_temp_choice, 1)[0]
            logger.debug(f"Selected temperature for diverse beam search: {random_temperature}")
-            output_re = self.model.generate(
+            output_re = model.generate(
                **inputs.to(device_type),
                max_new_tokens=300,
                temperature=random_temperature,
@@ -530,7 +557,7 @@
                return_dict_in_generate=True,
            )

-            transition_scores = self.model.compute_transition_scores(
+            transition_scores = model.compute_transition_scores(
                output_re.sequences, output_re.scores, output_re.beam_indices, normalize_logits=False
            )

@@ -539,7 +566,7 @@
            # Sum the True values along axis 1
            counts = torch.sum(mask, dim=1)
            output_length = inputs.input_ids.shape[1] + counts
-            length_penalty = self.model.generation_config.length_penalty
+            length_penalty = model.generation_config.length_penalty
            reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty)

            # Converting logit scores to prob scores
@@ -555,7 +582,7 @@
            )
            for idx, sorted_idx in enumerate(prob_sorted_idxs):
                _out = output_re.sequences[sorted_idx]
-                res = self.tokenizer.decode(_out[input_length:], skip_special_tokens=True)
+                res = tokenizer.decode(_out[input_length:], skip_special_tokens=True)
                result = res.replace("table_name", _table_name)
                if "LIMIT".lower() not in result.lower():
                    res = "SELECT " + result.strip() + " LIMIT 100;"
@@ -570,7 +597,7 @@
                    alternate_queries.append(alt_res)
                    logger.info(alt_res)

-        _res = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
+        _res = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        # Below is a precaution in case of an error in table name during generation
        # COLLATE NOCASE is used to ignore case sensitivity, this might be specific to sqlite
        _temp = _res.replace("table_name", table_name).split(";")[0]
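On the COLLATE NOCASE precaution in the final hunk: it makes the generated comparison case-insensitive so a miscased name or value still matches, and as the code comment notes it may be sqlite-specific. A self-contained illustration (table and data invented for the example):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE sleep_data (quality TEXT)")
conn.execute("INSERT INTO sleep_data VALUES ('Good')")
# COLLATE NOCASE makes the comparison case-insensitive, so 'good' matches
# 'Good'; without it, sqlite's default BINARY collation would return 0 rows.
count = conn.execute(
    "SELECT COUNT(*) FROM sleep_data WHERE quality = 'good' COLLATE NOCASE"
).fetchone()[0]
print(count)  # 1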
(Diff for the remaining 3 changed files not shown.)
