Skip to content

Commit

Permalink
Update database.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Emanuele Guidotti committed Feb 7, 2024
1 parent fe04fcf commit 99f34fc
Showing 1 changed file with 131 additions and 61 deletions.
192 changes: 131 additions & 61 deletions bornrule/sql/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,16 @@ def _sql_predict(self, items, cache):
return f"""
{self._sql_with_HW_jk(cache)},
{self._sql_X_nj(items)},
{self._sql_U_nk()},
{self._sql_R_nk()}
{self._sql_X_nj_a()},
{self._sql_HWX_nk()},
R_nk AS (
SELECT
{self.n},
{self.k},
ROW_NUMBER() OVER(PARTITION BY {self.n} ORDER BY {self.w} DESC) AS {self.w}
FROM
HWX_nk
)
SELECT
{self.n},
{self.k}
Expand All @@ -410,6 +418,8 @@ def _sql_predict_proba(self, items, cache):
return f"""
{self._sql_with_HW_jk(cache)},
{self._sql_X_nj(items)},
{self._sql_X_nj_a()},
{self._sql_HWX_nk()},
{self._sql_U_nk()},
{self._sql_U_n()}
SELECT
Expand Down Expand Up @@ -455,52 +465,62 @@ def _sql_with_HW_jk(self, cache):
{self._sql_ABH()},
{self._sql_P_j()},
{self._sql_P_k()},
{self._sql_P_j_b()},
{self._sql_P_k_b()},
{self._sql_W_jk()},
{self._sql_W_j()},
{self._sql_H_jk()},
{self._sql_LN()},
{self._sql_H_j()},
{self._sql_H_j_h()},
{self._sql_W_jk_a()},
{self._sql_HW_jk()}
"""

def _sql_C_nk(self, items):
return f"""
C_nk AS (
SELECT
C.{self.n},
C.{self.k},
C.{self.w}
FROM
({self._sql_transform(items['class'], concat=False, name=self.k)}) AS C,
({items['where']}) AS N
WHERE
C.{self.n} = N.{self.n} AND
C.{self.k} IS NOT NULL
)
"""

def _sql_X_nj(self, items):
if not isinstance(items, dict):
return f"X_nj AS (SELECT * FROM {items})"

return f"""
N AS (
{items['where']}
),
X AS ({
' UNION ALL '.join([
self._sql_transform(feature, concat=True, name=self.j)
for feature in items['features']])
}),
X_nj AS (
SELECT
X.{self.n},
X.{self.j},
X.{self.w}
FROM
({' UNION ALL '.join([
self._sql_transform(feature, concat=True, name=self.j)
for feature in items['features']
])}) AS X,
({items['where']}) AS N
X, N
WHERE
X.{self.n} = N.{self.n} AND
X.{self.j} IS NOT NULL
)
"""

def _sql_C_nk(self, items):
return f"""
C AS ({
self._sql_transform(items['class'], concat=False, name=self.k)
}),
C_nk AS (
SELECT
C.{self.n},
C.{self.k},
C.{self.w}
FROM
C, N
WHERE
C.{self.n} = N.{self.n} AND
C.{self.k} IS NOT NULL
)
"""

def _sql_X_n(self):
return f"""
X_n AS (
Expand Down Expand Up @@ -658,21 +678,40 @@ def _sql_P_j(self):
)
"""

def _sql_P_k_b(self):
return f"""
P_k_b AS (
SELECT
{self.k},
{self.POW}(P_k.{self.w}, - ABH.b) AS {self.w}
FROM
P_k, ABH
)
"""

def _sql_P_j_b(self):
return f"""
P_j_b AS (
SELECT
{self.j},
{self.POW}(P_j.{self.w}, ABH.b - 1) AS {self.w}
FROM
P_j, ABH
)
"""

def _sql_W_jk(self):
return f"""
W_jk AS (
SELECT
{self.table_corpus}.{self.j},
{self.table_corpus}.{self.k},
{self.table_corpus}.{self.w}
* {self.POW}(P_k.{self.w}, - ABH.b)
* {self.POW}(P_j.{self.w}, ABH.b - 1)
AS {self.w}
{self.table_corpus}.{self.w} * P_j_b.{self.w} * P_k_b.{self.w} AS {self.w}
FROM
{self.table_corpus}, P_j, P_k, ABH
{self.table_corpus}, P_j_b, P_k_b
WHERE
{self.table_corpus}.{self.j} = P_j.{self.j} AND
{self.table_corpus}.{self.k} = P_k.{self.k}
{self.table_corpus}.{self.j} = P_j_b.{self.j} AND
{self.table_corpus}.{self.k} = P_k_b.{self.k}
)
"""

Expand Down Expand Up @@ -708,70 +747,101 @@ def _sql_H_j(self):
H_j AS (
SELECT
H_jk.{self.j},
1 + {self.SUM}(
H_jk.{self.w} * {self.LOG}(H_jk.{self.w}) / LN.{self.w}
) AS {self.w}
{self.SUM}(H_jk.{self.w} * {self.LOG}(H_jk.{self.w})) AS {self.w}
FROM
H_jk, LN
H_jk
GROUP BY
{self.j}
)
"""

def _sql_HW_jk(self):
def _sql_H_j_h(self):
return f"""
HW_jk AS (
H_j_h AS (
SELECT
H_j.{self.j},
{self.POW}(1 + H_j.{self.w} / LN.{self.w}, ABH.h) AS {self.w}
FROM
H_j, ABH, LN
)
"""

def _sql_W_jk_a(self):
return f"""
W_jk_a AS (
SELECT
W_jk.{self.j},
W_jk.{self.k},
{self.POW}(W_jk.{self.w}, ABH.a) * {self.POW}(H_j.{self.w}, ABH.h) AS {self.w}
{self.POW}(W_jk.{self.w}, ABH.a) AS {self.w}
FROM
W_jk, ABH
)
"""

def _sql_HW_jk(self):
return f"""
HW_jk AS (
SELECT
W_jk_a.{self.j},
W_jk_a.{self.k},
W_jk_a.{self.w} * H_j_h.{self.w} AS {self.w}
FROM
W_jk, H_j, ABH
W_jk_a, H_j_h
WHERE
W_jk.{self.j} = H_j.{self.j}
W_jk_a.{self.j} = H_j_h.{self.j}
)
"""

def _sql_U_nk(self):
def _sql_X_nj_a(self):
return f"""
U_nk AS (
X_nj_a AS (
SELECT
X_nj.{self.n},
X_nj.{self.j},
{self.POW}(X_nj.{self.w}, ABH.a) AS {self.w}
FROM
X_nj, ABH
)
"""

def _sql_HWX_nk(self):
return f"""
HWX_nk AS (
SELECT
X_nj.{self.n},
X_nj_a.{self.n},
HW_jk.{self.k},
{self.POW}(
{self.SUM}(HW_jk.{self.w} * {self.POW}(X_nj.{self.w}, ABH.a)),
1 / ABH.a
) AS {self.w}
{self.SUM}(HW_jk.{self.w} * X_nj_a.{self.w}) AS {self.w}
FROM
X_nj, HW_jk, ABH
X_nj_a, HW_jk
WHERE
X_nj.{self.j} = HW_jk.{self.j}
X_nj_a.{self.j} = HW_jk.{self.j}
GROUP BY
{self.n}, {self.k}, ABH.a
X_nj_a.{self.n},
HW_jk.{self.k}
)
"""

def _sql_U_n(self):
def _sql_U_nk(self):
return f"""
U_n AS (
U_nk AS (
SELECT
{self.n},
{self.SUM}({self.w}) AS {self.w}
HWX_nk.{self.n},
HWX_nk.{self.k},
{self.POW}(HWX_nk.{self.w}, 1 / ABH.a) AS {self.w}
FROM
U_nk
GROUP BY
{self.n}
HWX_nk, ABH
)
"""
def _sql_R_nk(self):

def _sql_U_n(self):
return f"""
R_nk AS (
U_n AS (
SELECT
{self.n},
{self.k},
ROW_NUMBER() OVER(PARTITION BY {self.n} ORDER BY {self.w} DESC) AS {self.w}
{self.SUM}({self.w}) AS {self.w}
FROM
U_nk
GROUP BY
{self.n}
)
"""

0 comments on commit 99f34fc

Please sign in to comment.