diff --git a/bornrule/sql/database.py b/bornrule/sql/database.py index 4d5d36e..0978321 100644 --- a/bornrule/sql/database.py +++ b/bornrule/sql/database.py @@ -395,8 +395,16 @@ def _sql_predict(self, items, cache): return f""" {self._sql_with_HW_jk(cache)}, {self._sql_X_nj(items)}, - {self._sql_U_nk()}, - {self._sql_R_nk()} + {self._sql_X_nj_a()}, + {self._sql_HWX_nk()}, + R_nk AS ( + SELECT + {self.n}, + {self.k}, + ROW_NUMBER() OVER(PARTITION BY {self.n} ORDER BY {self.w} DESC) AS {self.w} + FROM + HWX_nk + ) SELECT {self.n}, {self.k} @@ -410,6 +418,8 @@ def _sql_predict_proba(self, items, cache): return f""" {self._sql_with_HW_jk(cache)}, {self._sql_X_nj(items)}, + {self._sql_X_nj_a()}, + {self._sql_HWX_nk()}, {self._sql_U_nk()}, {self._sql_U_n()} SELECT @@ -455,52 +465,62 @@ def _sql_with_HW_jk(self, cache): {self._sql_ABH()}, {self._sql_P_j()}, {self._sql_P_k()}, + {self._sql_P_j_b()}, + {self._sql_P_k_b()}, {self._sql_W_jk()}, {self._sql_W_j()}, {self._sql_H_jk()}, {self._sql_LN()}, {self._sql_H_j()}, + {self._sql_H_j_h()}, + {self._sql_W_jk_a()}, {self._sql_HW_jk()} """ - def _sql_C_nk(self, items): - return f""" - C_nk AS ( - SELECT - C.{self.n}, - C.{self.k}, - C.{self.w} - FROM - ({self._sql_transform(items['class'], concat=False, name=self.k)}) AS C, - ({items['where']}) AS N - WHERE - C.{self.n} = N.{self.n} AND - C.{self.k} IS NOT NULL - ) - """ - def _sql_X_nj(self, items): if not isinstance(items, dict): return f"X_nj AS (SELECT * FROM {items})" return f""" + N AS ( + {items['where']} + ), + X AS ({ + ' UNION ALL '.join([ + self._sql_transform(feature, concat=True, name=self.j) + for feature in items['features']]) + }), X_nj AS ( SELECT X.{self.n}, X.{self.j}, X.{self.w} FROM - ({' UNION ALL '.join([ - self._sql_transform(feature, concat=True, name=self.j) - for feature in items['features'] - ])}) AS X, - ({items['where']}) AS N + X, N WHERE X.{self.n} = N.{self.n} AND X.{self.j} IS NOT NULL ) """ + def _sql_C_nk(self, items): + return f""" + C AS ({ + self._sql_transform(items['class'], concat=False, name=self.k) + }), + C_nk AS ( + SELECT + C.{self.n}, + C.{self.k}, + C.{self.w} + FROM + C, N + WHERE + C.{self.n} = N.{self.n} AND + C.{self.k} IS NOT NULL + ) + """ + def _sql_X_n(self): return f""" X_n AS ( @@ -658,21 +678,40 @@ def _sql_P_j(self): ) """ + def _sql_P_k_b(self): + return f""" + P_k_b AS ( + SELECT + {self.k}, + {self.POW}(P_k.{self.w}, - ABH.b) AS {self.w} + FROM + P_k, ABH + ) + """ + + def _sql_P_j_b(self): + return f""" + P_j_b AS ( + SELECT + {self.j}, + {self.POW}(P_j.{self.w}, ABH.b - 1) AS {self.w} + FROM + P_j, ABH + ) + """ + def _sql_W_jk(self): return f""" W_jk AS ( SELECT {self.table_corpus}.{self.j}, {self.table_corpus}.{self.k}, - {self.table_corpus}.{self.w} - * {self.POW}(P_k.{self.w}, - ABH.b) - * {self.POW}(P_j.{self.w}, ABH.b - 1) - AS {self.w} + {self.table_corpus}.{self.w} * P_j_b.{self.w} * P_k_b.{self.w} AS {self.w} FROM - {self.table_corpus}, P_j, P_k, ABH + {self.table_corpus}, P_j_b, P_k_b WHERE - {self.table_corpus}.{self.j} = P_j.{self.j} AND - {self.table_corpus}.{self.k} = P_k.{self.k} + {self.table_corpus}.{self.j} = P_j_b.{self.j} AND + {self.table_corpus}.{self.k} = P_k_b.{self.k} ) """ @@ -708,70 +747,101 @@ def _sql_H_j(self): H_j AS ( SELECT H_jk.{self.j}, - 1 + {self.SUM}( - H_jk.{self.w} * {self.LOG}(H_jk.{self.w}) / LN.{self.w} - ) AS {self.w} + {self.SUM}(H_jk.{self.w} * {self.LOG}(H_jk.{self.w})) AS {self.w} FROM - H_jk, LN + H_jk GROUP BY {self.j} ) """ - def _sql_HW_jk(self): + def _sql_H_j_h(self): return f""" - HW_jk AS ( + H_j_h AS ( + SELECT + H_j.{self.j}, + {self.POW}(1 + H_j.{self.w} / LN.{self.w}, ABH.h) AS {self.w} + FROM + H_j, ABH, LN + ) + """ + + def _sql_W_jk_a(self): + return f""" + W_jk_a AS ( SELECT W_jk.{self.j}, W_jk.{self.k}, - {self.POW}(W_jk.{self.w}, ABH.a) * {self.POW}(H_j.{self.w}, ABH.h) AS {self.w} + {self.POW}(W_jk.{self.w}, ABH.a) AS {self.w} + FROM + W_jk, ABH + ) + """ + + def _sql_HW_jk(self): + return f""" + HW_jk AS ( + SELECT + W_jk_a.{self.j}, + W_jk_a.{self.k}, + W_jk_a.{self.w} * H_j_h.{self.w} AS {self.w} FROM - W_jk, H_j, ABH + W_jk_a, H_j_h WHERE - W_jk.{self.j} = H_j.{self.j} + W_jk_a.{self.j} = H_j_h.{self.j} ) """ - def _sql_U_nk(self): + def _sql_X_nj_a(self): return f""" - U_nk AS ( + X_nj_a AS ( + SELECT + X_nj.{self.n}, + X_nj.{self.j}, + {self.POW}(X_nj.{self.w}, ABH.a) AS {self.w} + FROM + X_nj, ABH + ) + """ + + def _sql_HWX_nk(self): + return f""" + HWX_nk AS ( SELECT - X_nj.{self.n}, + X_nj_a.{self.n}, HW_jk.{self.k}, - {self.POW}( - {self.SUM}(HW_jk.{self.w} * {self.POW}(X_nj.{self.w}, ABH.a)), - 1 / ABH.a - ) AS {self.w} + {self.SUM}(HW_jk.{self.w} * X_nj_a.{self.w}) AS {self.w} FROM - X_nj, HW_jk, ABH + X_nj_a, HW_jk WHERE - X_nj.{self.j} = HW_jk.{self.j} + X_nj_a.{self.j} = HW_jk.{self.j} GROUP BY - {self.n}, {self.k}, ABH.a + X_nj_a.{self.n}, + HW_jk.{self.k} ) """ - def _sql_U_n(self): + def _sql_U_nk(self): return f""" - U_n AS ( + U_nk AS ( SELECT - {self.n}, - {self.SUM}({self.w}) AS {self.w} + HWX_nk.{self.n}, + HWX_nk.{self.k}, + {self.POW}(HWX_nk.{self.w}, 1 / ABH.a) AS {self.w} FROM - U_nk - GROUP BY - {self.n} + HWX_nk, ABH ) """ - - def _sql_R_nk(self): + + def _sql_U_n(self): return f""" - R_nk AS ( + U_n AS ( SELECT {self.n}, - {self.k}, - ROW_NUMBER() OVER(PARTITION BY {self.n} ORDER BY {self.w} DESC) AS {self.w} + {self.SUM}({self.w}) AS {self.w} FROM U_nk + GROUP BY + {self.n} ) """