diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index 292b2839..119bd622 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -480,6 +480,7 @@ def ask( question: Union[str, None] = None, print_results: bool = True, auto_train: bool = True, + visualize: bool = True, # if False, will not generate plotly code ) -> Union[ Tuple[ Union[str, None], @@ -499,7 +500,7 @@ def ask( if print_results: try: - Code = __import__("IPython.display", fromlist=["Code"]).Code + Code = __import__("IPython.display", fromList=["Code"]).Code display(Code(sql)) except Exception as e: print(sql) @@ -515,19 +516,12 @@ def ask( return sql, None, None try: - if self.is_sql_valid(sql) is False: - print("SQL is not valid, please try again.") - if print_results: - return None - else: - return sql, None, None - df = self.run_sql(sql) if print_results: try: display = __import__( - "IPython.display", fromlist=["display"] + "IPython.display", fromList=["display"] ).display display(df) except Exception as e: @@ -535,32 +529,33 @@ def ask( if len(df) > 0 and auto_train: self.add_question_sql(question=question, sql=sql) - - try: - plotly_code = self.generate_plotly_code( - question=question, - sql=sql, - df_metadata=f"Running df.dtypes gives:\n {df.dtypes}", - ) - fig = self.get_plotly_figure(plotly_code=plotly_code, df=df) - if print_results: - try: - display = __import__( - "IPython.display", fromlist=["display"] - ).display - Image = __import__("IPython.display", fromlist=["Image"]).Image - img_bytes = fig.to_image(format="png", scale=2) - display(Image(img_bytes)) - except Exception as e: - fig.show() - except Exception as e: - # Print stack trace - traceback.print_exc() - print("Couldn't run plotly code: ", e) - if print_results: - return None - else: - return sql, df, None + # Only generate plotly code if visualize is True + if visualize: + try: + plotly_code = self.generate_plotly_code( + question=question, + sql=sql, + df_metadata=f"Running df.dtypes gives:\n {df.dtypes}", + ) + fig = self.get_plotly_figure(plotly_code=plotly_code, df=df) + if print_results: + try: + display = __import__( + "IPython.display", fromlist=["display"] + ).display + Image = __import__("IPython.display", fromlist=["Image"]).Image + img_bytes = fig.to_image(format="png", scale=2) + display(Image(img_bytes)) + except Exception as e: + fig.show() + except Exception as e: + # Print stack trace + traceback.print_exc() + print("Couldn't run plotly code: ", e) + if print_results: + return None + else: + return sql, df, None except Exception as e: print("Couldn't run sql: ", e)