From 30f7762c315afd88c1c1c06108c92e267d94f34f Mon Sep 17 00:00:00 2001 From: Pranay Buradkar Date: Fri, 3 Jan 2025 03:34:59 +0530 Subject: [PATCH 1/2] Partition key error of AstraDBCQL fixed. --- .../langflow/components/tools/astradb_cql.py | 79 ++++++++++++++----- 1 file changed, 60 insertions(+), 19 deletions(-) diff --git a/src/backend/base/langflow/components/tools/astradb_cql.py b/src/backend/base/langflow/components/tools/astradb_cql.py index 04f1da824aa3..f310f23c1276 100644 --- a/src/backend/base/langflow/components/tools/astradb_cql.py +++ b/src/backend/base/langflow/components/tools/astradb_cql.py @@ -91,33 +91,69 @@ class AstraDBCQLToolComponent(LCToolComponent): ] def astra_rest(self, args): - headers = {"Accept": "application/json", "X-Cassandra-Token": f"{self.token}"} - astra_url = f"{self.api_endpoint}/api/rest/v2/keyspaces/{self.keyspace}/{self.table_name}/" - key = [] + headers = {"Accept": "application/json", "X-Cassandra-Token": f"{self.token}", "Content-Type": "application/json"} - # Partition keys are mandatory - key = [self.partition_keys[k] for k in self.partition_keys] + astra_url = f"{self.api_endpoint}/api/rest/v2/keyspaces/{self.keyspace}/{self.table_name}/rows" - # Clustering keys are optional - for k in self.clustering_keys: - if k in args: - key.append(args[k]) - elif self.static_filters[k] is not None: - key.append(self.static_filters[k]) + where_clauses = [] - url = f"{astra_url}{'/'.join(key)}?page-size={self.number_of_results}" + for key in self.partition_keys: + if key in args: + where_clauses.append( + {"column": key, "operator": "EQ", "value": args[key]} + ) + elif key in self.static_filters: + where_clauses.append( + {"column": key, "operator": "EQ", "value": self.static_filters[key]} + ) - if self.projection_fields != "*": - url += f"&fields={urllib.parse.quote(self.projection_fields.replace(' ', ''))}" + for key in self.clustering_keys: + clean_key = key[1:] if key.startswith("!") else key + if clean_key in args and args[clean_key] is not None: + where_clauses.append( + {"column": clean_key, "operator": "EQ", "value": args[clean_key]} + ) + elif clean_key in self.static_filters: + where_clauses.append( + { + "column": clean_key, + "operator": "EQ", + "value": self.static_filters[clean_key], + } + ) + + params = { + "page-size": self.number_of_results, + "where": {"filters": where_clauses}, + } - res = requests.request("GET", url=url, headers=headers, timeout=10) + if self.projection_fields != "*": + params["fields"] = [ + field.strip() for field in self.projection_fields.split(",") + ] + + res = requests.request( + "GET", + url=astra_url, + headers=headers, + params={"raw": "true"}, + json=params, + timeout=10, + ) if int(res.status_code) >= HTTPStatus.BAD_REQUEST: return res.text try: res_data = res.json() - return res_data["data"] + if isinstance(res_data, dict) and "data" in res_data: + return res_data["data"] + elif isinstance(res_data, list): + return res_data + elif isinstance(res_data, dict): + return [res_data] + else: + return [] except ValueError: return res.status_code @@ -125,17 +161,18 @@ def create_args_schema(self) -> dict[str, BaseModel]: args: dict[str, tuple[Any, Field]] = {} for key in self.partition_keys: - # Partition keys are mandatory is it doesn't have a static filter if key not in self.static_filters: args[key] = (str, Field(description=self.partition_keys[key])) for key in self.clustering_keys: - # Partition keys are mandatory if has the exclamation mark and doesn't have a static filter if key not in self.static_filters: if key.startswith("!"): # Mandatory args[key[1:]] = (str, Field(description=self.clustering_keys[key])) else: # Optional - args[key] = (str | None, Field(description=self.clustering_keys[key], default=None)) + args[key] = ( + str | None, + Field(description=self.clustering_keys[key], default=None), + ) model = create_model("ToolInput", **args, __base__=BaseModel) return {"ToolInput": model} @@ -172,6 +209,10 @@ def projection_args(self, input_str: str) -> dict: def run_model(self, **args) -> Data | list[Data]: results = self.astra_rest(args) + + if isinstance(results, str): + raise ValueError(f"Error from Astra DB: {results}") + data: list[Data] = [Data(data=doc) for doc in results] self.status = data return results From 7efcbfafa196d88e509a71abfe2a8fe93e8a1ffd Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Thu, 2 Jan 2025 22:28:48 +0000 Subject: [PATCH 2/2] [autofix.ci] apply automated fixes --- .../langflow/components/tools/astradb_cql.py | 30 ++++++++----------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/src/backend/base/langflow/components/tools/astradb_cql.py b/src/backend/base/langflow/components/tools/astradb_cql.py index f310f23c1276..26092ff9e66a 100644 --- a/src/backend/base/langflow/components/tools/astradb_cql.py +++ b/src/backend/base/langflow/components/tools/astradb_cql.py @@ -1,4 +1,3 @@ -import urllib from http import HTTPStatus from typing import Any @@ -91,7 +90,11 @@ class AstraDBCQLToolComponent(LCToolComponent): ] def astra_rest(self, args): - headers = {"Accept": "application/json", "X-Cassandra-Token": f"{self.token}", "Content-Type": "application/json"} + headers = { + "Accept": "application/json", + "X-Cassandra-Token": f"{self.token}", + "Content-Type": "application/json", + } astra_url = f"{self.api_endpoint}/api/rest/v2/keyspaces/{self.keyspace}/{self.table_name}/rows" @@ -99,20 +102,14 @@ def astra_rest(self, args): for key in self.partition_keys: if key in args: - where_clauses.append( - {"column": key, "operator": "EQ", "value": args[key]} - ) + where_clauses.append({"column": key, "operator": "EQ", "value": args[key]}) elif key in self.static_filters: - where_clauses.append( - {"column": key, "operator": "EQ", "value": self.static_filters[key]} - ) + where_clauses.append({"column": key, "operator": "EQ", "value": self.static_filters[key]}) for key in self.clustering_keys: clean_key = key[1:] if key.startswith("!") else key if clean_key in args and args[clean_key] is not None: - where_clauses.append( - {"column": clean_key, "operator": "EQ", "value": args[clean_key]} - ) + where_clauses.append({"column": clean_key, "operator": "EQ", "value": args[clean_key]}) elif clean_key in self.static_filters: where_clauses.append( { @@ -128,9 +125,7 @@ def astra_rest(self, args): } if self.projection_fields != "*": - params["fields"] = [ - field.strip() for field in self.projection_fields.split(",") - ] + params["fields"] = [field.strip() for field in self.projection_fields.split(",")] res = requests.request( "GET", @@ -148,12 +143,11 @@ def astra_rest(self, args): res_data = res.json() if isinstance(res_data, dict) and "data" in res_data: return res_data["data"] - elif isinstance(res_data, list): + if isinstance(res_data, list): return res_data - elif isinstance(res_data, dict): + if isinstance(res_data, dict): return [res_data] - else: - return [] + return [] except ValueError: return res.status_code