Skip to content

Commit

Permalink
add a visualize parm for the ask function
Browse files Browse the repository at this point in the history
  • Loading branch information
MoonTidef committed Jan 18, 2024
1 parent c804513 commit c1a018a
Showing 1 changed file with 30 additions and 35 deletions.
65 changes: 30 additions & 35 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ def ask(
question: Union[str, None] = None,
print_results: bool = True,
auto_train: bool = True,
visualize: bool = True, # if False, will not generate plotly code
) -> Union[
Tuple[
Union[str, None],
Expand All @@ -499,7 +500,7 @@ def ask(

if print_results:
try:
Code = __import__("IPython.display", fromlist=["Code"]).Code
Code = __import__("IPython.display", fromList=["Code"]).Code
display(Code(sql))
except Exception as e:
print(sql)
Expand All @@ -515,52 +516,46 @@ def ask(
return sql, None, None

try:
if self.is_sql_valid(sql) is False:
print("SQL is not valid, please try again.")
if print_results:
return None
else:
return sql, None, None

df = self.run_sql(sql)

if print_results:
try:
display = __import__(
"IPython.display", fromlist=["display"]
"IPython.display", fromList=["display"]
).display
display(df)
except Exception as e:
print(df)

if len(df) > 0 and auto_train:
self.add_question_sql(question=question, sql=sql)

try:
plotly_code = self.generate_plotly_code(
question=question,
sql=sql,
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
)
fig = self.get_plotly_figure(plotly_code=plotly_code, df=df)
if print_results:
try:
display = __import__(
"IPython.display", fromlist=["display"]
).display
Image = __import__("IPython.display", fromlist=["Image"]).Image
img_bytes = fig.to_image(format="png", scale=2)
display(Image(img_bytes))
except Exception as e:
fig.show()
except Exception as e:
# Print stack trace
traceback.print_exc()
print("Couldn't run plotly code: ", e)
if print_results:
return None
else:
return sql, df, None
# Only generate plotly code if visualize is True
if visualize:
try:
plotly_code = self.generate_plotly_code(
question=question,
sql=sql,
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
)
fig = self.get_plotly_figure(plotly_code=plotly_code, df=df)
if print_results:
try:
display = __import__(
"IPython.display", fromlist=["display"]
).display
Image = __import__("IPython.display", fromlist=["Image"]).Image
img_bytes = fig.to_image(format="png", scale=2)
display(Image(img_bytes))
except Exception as e:
fig.show()
except Exception as e:
# Print stack trace
traceback.print_exc()
print("Couldn't run plotly code: ", e)
if print_results:
return None
else:
return sql, df, None

except Exception as e:
print("Couldn't run sql: ", e)
Expand Down

0 comments on commit c1a018a

Please sign in to comment.