Skip to content

Commit

Permalink
support hive create function (#500)
Browse files Browse the repository at this point in the history
* Support Hive CREATE FUNCTION statements:
Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL-CreateFunction

* Pylint

* func: make get_switch_by_create_query a private method
  • Loading branch information
MiuNice authored Jun 15, 2024
1 parent ece5ace commit dbaa329
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
1 change: 1 addition & 0 deletions sql_metadata/keywords_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class TokenType(str, Enum):
"CREATETABLE": QueryType.CREATE,
"ALTERTABLE": QueryType.ALTER,
"DROPTABLE": QueryType.DROP,
"CREATEFUNCTION": QueryType.CREATE,
}

# all the keywords we care for - rest is ignored in assigning
Expand Down
20 changes: 19 additions & 1 deletion sql_metadata/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def query_type(self) -> str:
)
.position
)
if tokens[index].normalized in ["CREATE", "ALTER", "DROP"]:
if tokens[index].normalized == "CREATE":
switch = self._get_switch_by_create_query(tokens, index)
elif tokens[index].normalized in ("ALTER", "DROP"):
switch = tokens[index].normalized + tokens[index + 1].normalized
else:
switch = tokens[index].normalized
Expand Down Expand Up @@ -1079,3 +1081,19 @@ def _flatten_sqlparse(self):
yield tok
else:
yield token

@staticmethod
def _get_switch_by_create_query(tokens: List[SQLToken], index: int) -> str:
    """
    Build the lookup key ("switch") used to resolve the type of a CREATE query.

    :param tokens: tokens of the query; each must expose ``normalized``
    :param index: position of the ``CREATE`` token within ``tokens``
    :return: ``"CREATEFUNCTION"`` when ``FUNCTION`` is one of the two tokens
        following ``CREATE`` (covers ``CREATE FUNCTION`` and Hive's
        ``CREATE TEMPORARY FUNCTION``); otherwise the normalized ``CREATE``
        token concatenated with the next token (e.g. ``"CREATETABLE"``), or
        the ``CREATE`` token alone when nothing follows it
    """
    # Slicing never raises, so a query that ends right after CREATE yields
    # an empty lookahead instead of an IndexError.
    lookahead = [token.normalized for token in tokens[index + 1 : index + 3]]

    # Hive CREATE FUNCTION
    if "FUNCTION" in lookahead:
        return "CREATEFUNCTION"

    return tokens[index].normalized + "".join(lookahead[:1])
19 changes: 19 additions & 0 deletions test/test_query_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,22 @@ def test_multiple_redundant_parentheses_create():
"""
parser = Parser(query)
assert parser.query_type == QueryType.CREATE


def test_hive_create_function():
    """Hive CREATE FUNCTION statements must be classified as CREATE queries."""
    # FUNCTION directly follows CREATE.
    permanent_udf_sql = """
CREATE FUNCTION simple_udf AS 'com.example.hive.udf.SimpleUDF'
USING JAR 'hdfs:///user/hive/udfs/simple-udf.jar'
WITH SERDEPROPERTIES (
"hive.udf.param1"="value1",
"hive.udf.param2"="value2"
);
"""
    # TEMPORARY sits between CREATE and FUNCTION.
    temporary_udf_sql = """
CREATE TEMPORARY FUNCTION myudf AS 'com.udf.myudf';
"""

    # Checked in the same order as written: permanent first, then temporary.
    for statement in (permanent_udf_sql, temporary_udf_sql):
        assert Parser(statement).query_type == QueryType.CREATE

0 comments on commit dbaa329

Please sign in to comment.