Storing outcomes of similarity functions across a comparison to avoid re-compute #2540

Open
lamaeldo opened this issue Dec 6, 2024 · 1 comment
Labels
enhancement New feature or request

Comments

@lamaeldo

lamaeldo commented Dec 6, 2024

Is your proposal related to a problem?

When building custom comparisons, it is common to use the same continuous similarity measure at several comparison levels with different thresholds (for example, Jaro-Winkler > 0.9, > 0.8, > 0.7). Currently, the similarity function has to be computed separately for each comparison level, which is inefficient.

Describe the solution you'd like

For a given comparison, it would make sense to store the value of a similarity function until the end of the comparison, so that subsequent levels using the same similarity function can retrieve it instead of recomputing it.

@lamaeldo lamaeldo added the enhancement New feature or request label Dec 6, 2024
@lamaeldo lamaeldo changed the title [FEAT] <title> Storing outcomes of similarity functions across a comparison to avoid re-compute Dec 6, 2024
@RobinL
Member

RobinL commented Dec 14, 2024

Starting to experiment with this as follows:

import sqlglot
from sqlglot import exp

sql = """
SELECT
    CASE
        WHEN levenshtein(name_l, name_r) < 2 THEN 0
        WHEN levenshtein(name_l, name_r) < 4 THEN 1
        WHEN levenshtein(name_l, name_r) < 6 THEN 2
        WHEN jaro(name_l, name_r) < 0.9 THEN 2
        WHEN jaro(name_l, name_r) < 0.8 THEN 2
        ELSE 3
    END as similarity_bin
FROM joined_names
"""


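# Parse the query into a sqlglot expression tree.
expression = sqlglot.parse_one(sql)

# Count how often each function call (keyed by its SQL text) appears
# inside the WHEN branches of every CASE expression.
function_counts = {}
for case_expr in expression.find_all(exp.Case):
    for when_expr in case_expr.find_all(exp.If):
        for func in when_expr.find_all(exp.Func):
            func_sql = func.sql()
            function_counts[func_sql] = function_counts.get(func_sql, 0) + 1


# Function calls that occur more than once are candidates for reuse.
repeated_functions = {func for func, count in function_counts.items() if count > 1}


def transform(node):
    # Replace each repeated function call with a string literal built from
    # its SQL text, with every non-alphanumeric character turned into "_".
    if isinstance(node, exp.Func) and node.sql() in repeated_functions:
        cleaned = "".join(c if c.isalnum() else "_" for c in node.sql().lower())
        return exp.Literal.string(cleaned)
    return node


transformed = expression.transform(transform)
print(transformed.sql())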
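
One possible next step, sketched with the standard library only and under stated assumptions: rather than substituting a string literal, each repeated call could be computed once in an inner query under an alias, with the CASE branches referencing that alias. The helper `hoist_repeated_calls`, the `(call, condition, label)` tuple layout, the `precomputed` alias and the generated query shape are all hypothetical; the alias naming reuses the `cleaned` rule from the experiment.

```python
from collections import Counter

# Each level: (function call SQL, rest of the condition, resulting bin).
levels = [
    ("levenshtein(name_l, name_r)", "< 2", 0),
    ("levenshtein(name_l, name_r)", "< 4", 1),
    ("levenshtein(name_l, name_r)", "< 6", 2),
    ("jaro(name_l, name_r)", "< 0.9", 2),
    ("jaro(name_l, name_r)", "< 0.8", 2),
]


def alias_for(call_sql: str) -> str:
    """Same cleaning rule as the experiment: non-alphanumerics become '_'."""
    return "".join(c if c.isalnum() else "_" for c in call_sql.lower())


def hoist_repeated_calls(levels, else_label, source_table):
    """Build SQL in which each repeated call appears once, in an inner query."""
    counts = Counter(call for call, _, _ in levels)
    aliases = {call: alias_for(call) for call, n in counts.items() if n > 1}

    # Inner query: all source columns plus one aliased column per repeated call.
    select_list = ", ".join(
        ["*"] + [f"{call} AS {alias}" for call, alias in aliases.items()]
    )
    # Outer CASE: repeated calls are replaced by their alias, others kept as-is.
    whens = "\n".join(
        f"        WHEN {aliases.get(call, call)} {cond} THEN {label}"
        for call, cond, label in levels
    )
    return (
        "SELECT\n    CASE\n"
        f"{whens}\n"
        f"        ELSE {else_label}\n"
        "    END AS similarity_bin\n"
        f"FROM (SELECT {select_list} FROM {source_table}) AS precomputed"
    )


rewritten_sql = hoist_repeated_calls(levels, else_label=3, source_table="joined_names")
print(rewritten_sql)
```

Whether a given SQL engine then evaluates each call only once depends on its optimizer, which this sketch does not verify.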
