Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add standardized and overridable logging #234

Closed
wants to merge 9 commits into from
16 changes: 8 additions & 8 deletions src/vanna/ZhipuAI/ZhipuAI_Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
from typing import List
import pandas as pd

from ..logger import get_logger

_logger = get_logger()


class ZhipuAI_Chat(VannaBase):
def __init__(self, config=None):
VannaBase.__init__(self, config=config)
Expand Down Expand Up @@ -108,7 +113,7 @@ def get_sql_prompt(

for example in question_sql_list:
if example is None:
print("example is None")
_logger.info("Example is None")
else:
if example is not None and "question" in example and "sql" in example:
message_log.append(ZhipuAI_Chat.user_message(example["question"]))
Expand Down Expand Up @@ -220,19 +225,14 @@ def submit_prompt(
if len(prompt) == 0:
raise Exception("Prompt is empty")

client = ZhipuAI(api_key=self.api_key) # 填写您自己的APIKey
client = ZhipuAI(api_key=self.api_key) # set your own API key
response = client.chat.completions.create(
model="glm-4", # 填写需要调用的模型名称
model="glm-4", # set your own model name
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=stop,
messages=prompt,
)
# print(prompt)

# print(response)

# print(f"Cost {response.usage.total_tokens} token")

return response.choices[0].message.content
9 changes: 6 additions & 3 deletions src/vanna/ZhipuAI/ZhipuAI_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from chromadb import Documents, EmbeddingFunction, Embeddings
from ..base import VannaBase

from ..logger import get_logger

_logger = get_logger()


class ZhipuAI_Embeddings(VannaBase):
"""
[future functionality] This function is used to generate embeddings from ZhipuAI.
Expand Down Expand Up @@ -60,7 +65,7 @@ def __call__(self, input: Documents) -> Embeddings:
# Replace newlines, which can negatively affect performance.
input = [t.replace("\n", " ") for t in input]
all_embeddings = []
print(f"Generating embeddings for {len(input)} documents")
_logger.info(f"Generating embeddings for {len(input)} documents")

# Iterating over each document for individual API calls
for document in input:
Expand All @@ -69,10 +74,8 @@ def __call__(self, input: Documents) -> Embeddings:
model=self.model_name,
input=document
)
# print(response)
embedding = response.data[0].embedding
all_embeddings.append(embedding)
# print(f"Cost required: {response.usage.total_tokens}")
except Exception as e:
raise ValueError(f"Error generating embedding for document: {e}")

Expand Down
99 changes: 54 additions & 45 deletions src/vanna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@
SQLRemoveError,
ValidationError,
)
from .logger import initialize_logger, setVerbosity
from .types import (
AccuracyStats,
ApiKey,
Expand Down Expand Up @@ -184,13 +185,15 @@
)
from .utils import sanitize_model_name, validate_config_path

_logger = initialize_logger()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should not configure the output of the logger at the global level. Then as soon as someone imports this module, the logger will start printing to stdout. Instead, configure it as one of the first things in the actual main() function.

Also, if this is the only place where initialize_logger() is used, it doesn't need to be in that distant module, but just define it directly here.

Copy link
Contributor Author

@andreped andreped Feb 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then as soon as someone imports this module, the logger will start printing to stdout

I thought the goal was to standardize all prints to the new logging solution? If you import vanna using my fix, all prints will follow this new standard. You are also free to change it yourself from outside the framework, if you'd like, using vn.setVerbosity("DEBUG"). That sounded like behaviour that we would like, no?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I see what you're saying. I think it just comes down to what we want as the default:

  • default to not configuring the logger, and then users opt-in to logging with logging.basicConfig(...) or the more finetipped logging.getLogger("vanna").do_stuff(...).
  • default to configuring the logger, and have users opt-out.

Currently the behavior is default-on. But arguably the reason for that is just because we can't turn print() off. Some libraries go for default on (eg splink), others go for default off, flask is more nuanced. I prefer quietness, so I prefer to not do anything in library code, and let the end app developer configure logging. But I see the pros/cons of all 3.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quote from there:

Note It is strongly advised that you do not add any handlers other than NullHandler to your library’s loggers. This is because the configuration of handlers is the prerogative of the application developer who uses your library. The application developer knows their target audience and what handlers are most appropriate for their application: if you add handlers ‘under the hood’, you might well interfere with their ability to carry out unit tests and deliver logs which suit their requirements.

So at least I think we should lean on the conservative side of adding handlers, maybe doing something similar to flask?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For all CLI/main entrypoints of this lib, we should configure logging of course, so behavior will stay similar to today. But if I'm using vanna as a lib, this is when behavior would change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I see what you mean. It was a little unclear to me what you wanted the solution to look like. But still, we do not want print() inside the framework, as we cannot mute these. A good example is that when training Vanna, the documents and SQLs will be printed directly in the console. This is unfortunate for production use.

I'm fine with removing the _logger solution given the new insight, but I guess I could revert to just doing logging.info() (and similar) and then it is up to the end user to change the logging behaviour.

I can make an attempt at this later today. A bit busy atm.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really appreciate the work here. I am rereading my review and it sounds a bit harsh, I could have been more graceful, so I'm sorry if I came off as a jerk.

Maybe wait for @zainhoda to chime in before you go through the rewrite? They might have a different idea totally for what they want the result to look like.

Copy link
Contributor Author

@andreped andreped Feb 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really appreciate the work here. I am rereading my review and it sounds a bit harsh, I could have been more graceful, so I'm sorry if I came off as a jerk.

@NickCrews No worries! I figured it was late in the day ;) Have done some open-source and learned to take some criticism on my bad ideas and code suggestions. So its all good!

Maybe wait for @zainhoda to chime in before you go through the rewrite? They might have a different idea totally for what they want the result to look like.

Yeah, that sounds like a good idea. I am in no rush. Just ping me if I fail to see an update :]

But it is a little unfortunate that when training Vanna, the training data is printed directly in the console. Not ideal for our production use case...


api_key: Union[str, None] = None # API key for Vanna.AI

fig_as_img: bool = False # Whether or not to return Plotly figures as images

run_sql: Union[
Callable[[str], pd.DataFrame], None
] = None # Function to convert SQL to a Pandas DataFrame
run_sql: Union[Callable[[str], pd.DataFrame], None] = (
None # Function to convert SQL to a Pandas DataFrame
)
"""
**Example**
```python
Expand Down Expand Up @@ -233,7 +236,7 @@ def __rpc_call(method, params):
raise ImproperlyConfigured(
"model not set. Use vn.set_model(...) to set the model to use."
)

if method == "list_orgs":
headers = {
"Content-Type": "application/json",
Expand Down Expand Up @@ -424,7 +427,7 @@ def add_user_to_model(model: str, email: str, is_admin: bool) -> bool:
status = Status(**d["result"])

if not status.success:
print(status.message)
_logger.error(status.message)

return status.success

Expand Down Expand Up @@ -860,7 +863,7 @@ def get_training_plan_experimental(

if use_historical_queries:
try:
print("Trying query history")
_logger.info("Trying query history")
df_history = run_sql(
""" select * from table(information_schema.query_history(result_limit => 5000)) order by start_time"""
)
Expand Down Expand Up @@ -901,7 +904,7 @@ def get_training_plan_experimental(
)

except Exception as e:
print(e)
_logger.error(e)

databases = __get_databases()

Expand All @@ -912,7 +915,7 @@ def get_training_plan_experimental(
try:
df_tables = __get_information_schema_tables(database=database)

print(f"Trying INFORMATION_SCHEMA.COLUMNS for {database}")
_logger.info(f"Trying INFORMATION_SCHEMA.COLUMNS for {database}")
df_columns = run_sql(f"SELECT * FROM {database}.INFORMATION_SCHEMA.COLUMNS")

for schema in df_tables["TABLE_SCHEMA"].unique().tolist():
Expand Down Expand Up @@ -959,18 +962,18 @@ def get_training_plan_experimental(
)

except Exception as e:
print(e)
_logger.error(e)
pass
except Exception as e:
print(e)
_logger.error(e)

# try:
# print("Trying SHOW TABLES")
# _logger.info("Trying SHOW TABLES")
# df_f = run_sql("SHOW TABLES")

# for schema in df_f.schema_name.unique():
# try:
# print(f"Trying GET_DDL for {schema}")
# _logger.info(f"Trying GET_DDL for {schema}")
# ddl_df = run_sql(f"SELECT GET_DDL('schema', '{schema}')")

# plan._plan.append(TrainingPlanItem(
Expand All @@ -983,13 +986,13 @@ def get_training_plan_experimental(
# pass
# except:
# try:
# print("Trying INFORMATION_SCHEMA.TABLES")
# _logger.info("Trying INFORMATION_SCHEMA.TABLES")
# df = run_sql("SELECT * FROM INFORMATION_SCHEMA.TABLES")

# breakpoint()

# try:
# print("Trying SCHEMATA")
# _logger.info("Trying SCHEMATA")
# df_schemata = run_sql("SELECT * FROM region-us.INFORMATION_SCHEMA.SCHEMATA")

# for schema in df_schemata.schema_name.unique():
Expand Down Expand Up @@ -1060,27 +1063,27 @@ def train(
)

if documentation:
print("Adding documentation....")
_logger.info("Adding documentation....")
return add_documentation(documentation)

if sql:
if question is None:
question = generate_question(sql)
print("Question generated with sql:", question, "\nAdding SQL...")
_logger.info(f"Question generated with sql: {question}\nAdding SQL...")
return add_sql(question=question, sql=sql)

if ddl:
print("Adding ddl:", ddl)
_logger.info(f"Adding ddl: {ddl}")
return add_ddl(ddl)

if json_file:
validate_config_path(json_file)
with open(json_file, "r") as js_file:
data = json.load(js_file)
print("Adding Questions And SQLs using file:", json_file)
_logger.info(f"Adding Questions And SQLs using file: {json_file}")
for question in data:
if not add_sql(question=question["question"], sql=question["answer"]):
print(
_logger.error(
f"Not able to add sql for question: {question['question']} from {json_file}"
)
return False
Expand All @@ -1093,34 +1096,36 @@ def train(
for statement in sql_statements:
if "CREATE TABLE" in statement:
if add_ddl(statement):
print("ddl Added!")
_logger.info("ddl Added!")
return True
print("Not able to add DDL")
_logger.error("Not able to add DDL")
return False
else:
question = generate_question(sql=statement)
if add_sql(question=question, sql=statement):
print("SQL added!")
_logger.info("SQL added!")
return True
print("Not able to add sql.")
_logger.error("Not able to add sql.")
return False
return False

if plan:
for item in plan._plan:
if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL:
if not add_ddl(item.item_value):
print(f"Not able to add ddl for {item.item_group}")
_logger.error(f"Not able to add ddl for {item.item_group}")
return False
elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS:
if not add_documentation(item.item_value):
print(
_logger.error(
f"Not able to add documentation for {item.item_group}.{item.item_name}"
)
return False
elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL:
if not add_sql(question=item.item_name, sql=item.item_value):
print(f"Not able to add sql for {item.item_group}.{item.item_name}")
_logger.error(
f"Not able to add sql for {item.item_group}.{item.item_name}"
)
return False


Expand Down Expand Up @@ -1438,18 +1443,18 @@ def ask(
try:
sql = generate_sql(question=question)
except Exception as e:
print(e)
_logger.error(e)
return None, None, None, None

if print_results:
try:
Code = __import__("IPython.display", fromlist=["Code"]).Code
display(Code(sql))
except Exception as e:
print(sql)
_logger.error(sql)

if run_sql is None:
print("If you want to run the SQL query, provide a vn.run_sql function.")
_logger.info("If you want to run the SQL query, provide a vn.run_sql function.")

if print_results:
return None
Expand All @@ -1464,7 +1469,7 @@ def ask(
display = __import__("IPython.display", fromlist=["display"]).display
display(df)
except Exception as e:
print(df)
_logger.error(df)

if len(df) > 0 and auto_train:
add_sql(question=question, sql=sql, tag=types.QuestionCategory.SQL_RAN)
Expand Down Expand Up @@ -1513,7 +1518,7 @@ def ask(
).Markdown
display(Markdown(md))
except Exception as e:
print(md)
_logger.error(md)

if print_results:
return None
Expand All @@ -1526,16 +1531,17 @@ def ask(
return sql, df, fig, None

except Exception as e:
# Print stack trace
traceback.print_exc()
print("Couldn't run plotly code: ", e)
_logger.error(traceback.print_exc())
_logger.error(f"Couldn't run plotly code: {e}")

if print_results:
return None
else:
return sql, df, None, None

except Exception as e:
print("Couldn't run sql: ", e)
_logger.error(f"Couldn't run sql: {e}")

if print_results:
return None
else:
Expand Down Expand Up @@ -1648,7 +1654,7 @@ def get_results(cs, default_database: str, sql: str) -> pd.DataFrame:
Returns:
pd.DataFrame: The results of the SQL query.
"""
print("`vn.get_results()` is deprecated. Use `vn.run_sql()` instead.")
_logger.warning("`vn.get_results()` is deprecated. Use `vn.run_sql()` instead.")
warnings.warn("`vn.get_results()` is deprecated. Use `vn.run_sql()` instead.")

cs.execute(f"USE DATABASE {default_database}")
Expand Down Expand Up @@ -2015,7 +2021,9 @@ def connect_to_postgres(

def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
try:
with conn.cursor() as cs: # Using a with statement to manage the cursor lifecycle
with (
conn.cursor() as cs
): # Using a with statement to manage the cursor lifecycle
cs.execute(sql)
results = cs.fetchall()
df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description])
Expand Down Expand Up @@ -2074,14 +2082,14 @@ def connect_to_bigquery(cred_file_path: str = None, project_id: str = None):
except Exception as e:
raise ImproperlyConfigured(e)
else:
print("Not using Google Colab.")
_logger.info("Not using Google Colab.")

conn = None

try:
conn = bigquery.Client(project=project_id)
except:
print("Could not found any google cloud implicit credentials")
_logger.error("Could not found any google cloud implicit credentials")

if cred_file_path:
# Validate file path and permissions
Expand Down Expand Up @@ -2122,7 +2130,8 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
global run_sql
run_sql = run_sql_bigquery

def connect_to_duckdb(url: str="memory", init_sql: str = None):

def connect_to_duckdb(url: str = "memory", init_sql: str = None):
"""
Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql]

Expand All @@ -2141,13 +2150,13 @@ def connect_to_duckdb(url: str="memory", init_sql: str = None):
" run command: \npip install vanna[duckdb]"
)
# URL of the database to download
if url==":memory:" or url=="":
path=":memory:"
if url == ":memory:" or url == "":
path = ":memory:"
else:
# Path to save the downloaded database
print(os.path.exists(url))
_logger.info(os.path.exists(url))
if os.path.exists(url):
path=url
path = url
else:
path = os.path.basename(urlparse(url).path)
# Download the database if it doesn't exist
Expand All @@ -2166,4 +2175,4 @@ def run_sql_duckdb(sql: str):
return conn.query(sql).to_df()

global run_sql
run_sql = run_sql_duckdb
run_sql = run_sql_duckdb
Loading
Loading