Skip to content

Commit

Permalink
Merge pull request #2 from bigdata-ustc/fix-link-vars
Browse files Browse the repository at this point in the history
[BUGFIX] invalid link formulas
  • Loading branch information
tswsxk authored May 22, 2021
2 parents 15e457e + b84ffab commit cc2fb0d
Show file tree
Hide file tree
Showing 11 changed files with 281 additions and 79 deletions.
204 changes: 159 additions & 45 deletions EduNLP/Formula/Formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,84 @@
# 2021/3/8 @ tongshiwei

from pprint import pformat
from typing import List
from typing import List, Dict
import networkx as nx
from copy import deepcopy

from .ast import str2ast, get_edges, ast, link_variable
from .ast import str2ast, get_edges, link_variable

CONST_MATHORD = {r"\pi"}

__all__ = ["Formula", "FormulaGroup", "CONST_MATHORD"]
__all__ = ["Formula", "FormulaGroup", "CONST_MATHORD", "link_formulas"]


class Formula(object):
def __init__(self, formula, is_str=True, variable_standardization=False, const_mathord=None):
self._formula = formula
self._ast = str2ast(formula) if is_str else formula
if variable_standardization:
const_mathord = CONST_MATHORD if const_mathord is None else const_mathord
self.variable_standardization(inplace=True, const_mathord=const_mathord)
"""
Examples
--------
>>> f = Formula("x")
>>> f
<Formula: x>
>>> f.ast
[{'val': {'id': 0, 'type': 'mathord', 'text': 'x', 'role': None}, \
'structure': {'bro': [None, None], 'child': None, 'father': None, 'forest': None}}]
>>> f.elements
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None}]
>>> f.variable_standardization(inplace=True)
<Formula: x>
>>> f.elements
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None, 'var': 0}]
"""

def __init__(self, formula: (str, List[Dict]), variable_standardization=False, const_mathord=None,
*args, **kwargs):
"""
def variable_standardization(self, inplace=False, const_mathord=None):
Parameters
----------
formula: str or List[Dict]
latex formula string or the parsed abstracted syntax tree
variable_standardization
const_mathord
args
kwargs
"""
self._formula = formula
self._ast = None
self.reset_ast(
formula_ensure_str=False,
variable_standardization=variable_standardization,
const_mathord=const_mathord, *args, **kwargs
)

def variable_standardization(self, inplace=False, const_mathord=None, variable_connect_dict=None):
const_mathord = const_mathord if const_mathord is not None else CONST_MATHORD
ast_tree = self._ast if inplace else deepcopy(self._ast)
variables = {}
index = 0
var_code = variable_connect_dict["var_code"] if variable_connect_dict is not None else {}
for node in ast_tree:
if node["val"]["type"] == "mathord":
var = node["val"]["text"]
if var in const_mathord:
continue
else:
if var not in variables:
variables[var], index = index, index + 1
node["val"]["var"] = variables[var]
if var not in var_code:
var_code[var] = len(var_code)
node["val"]["var"] = var_code[var]
if inplace:
return self
else:
return Formula(ast_tree, is_str=False)

@property
def element(self):
def ast(self):
return self._ast

@property
def ast(self) -> (nx.Graph, nx.DiGraph):
def elements(self):
return [self.ast_graph.nodes[node] for node in self.ast_graph.nodes]

@property
def ast_graph(self) -> (nx.Graph, nx.DiGraph):
edges = [(edge[0], edge[1]) for edge in get_edges(self._ast) if edge[2] == 3]
tree = nx.DiGraph()
for node in self._ast:
Expand All @@ -60,38 +94,85 @@ def to_str(self):
return pformat(self._ast)

def __repr__(self):
return "<Formula: %s>" % self._formula
if isinstance(self._formula, str):
return "<Formula: %s>" % self._formula
else:
return super(Formula, self).__repr__()

def reset_ast(self, formula_ensure_str: bool = True, variable_standardization=False, const_mathord=None, *args,
**kwargs):
if formula_ensure_str is True and self.resetable is False:
raise TypeError("formula must be str, now is %s" % type(self._formula))
self._ast = str2ast(self._formula, *args, **kwargs) if isinstance(self._formula, str) else self._formula
if variable_standardization:
const_mathord = CONST_MATHORD if const_mathord is None else const_mathord
self.variable_standardization(inplace=True, const_mathord=const_mathord)
return self._ast

class FormulaGroup(object):
def __init__(self, formula_list: List[str], variable_standardization=False, const_mathord=None):
"""
@property
def resetable(self):
return isinstance(self._formula, str)

Parameters
----------
formula_list: List[str]
"""
forest_begin = 0

class FormulaGroup(object):
"""
Examples
---------
>>> fg = FormulaGroup(["x + y", "y + x", "z + x"])
>>> fg
<FormulaGroup: <Formula: x + y>;<Formula: y + x>;<Formula: z + x>>
>>> fg = FormulaGroup(["x + y", Formula("y + x"), "z + x"])
>>> fg
<FormulaGroup: <Formula: x + y>;<Formula: y + x>;<Formula: z + x>>
>>> fg = FormulaGroup(["x", Formula("y"), "x"])
>>> fg.elements
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None}, {'id': 1, 'type': 'mathord', 'text': 'y', 'role': None},\
{'id': 2, 'type': 'mathord', 'text': 'x', 'role': None}]
>>> fg = FormulaGroup(["x", Formula("y"), "x"], variable_standardization=True)
>>> fg.elements
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None, 'var': 0}, \
{'id': 1, 'type': 'mathord', 'text': 'y', 'role': None, 'var': 1}, \
{'id': 2, 'type': 'mathord', 'text': 'x', 'role': None, 'var': 0}]
"""

def __init__(self,
formula_list: (list, List[str], List[Formula]),
variable_standardization=False,
const_mathord=None,
detach=True
):
forest = []
formula_sep_index = []
for index in range(0, len(formula_list)):
formula_sep_index.append(forest_begin)
tree = ast(
formula_list[index],
forest_begin=forest_begin,
is_str=True
)
forest_begin += len(tree)
self._formulas = []
for formula in formula_list:
if isinstance(formula, str):
formula = Formula(
formula,
forest_begin=len(forest),
)
self._formulas.append(formula)
tree = formula.ast
elif isinstance(formula, Formula):
if detach:
formula = deepcopy(formula)
tree = formula.reset_ast(
formula_ensure_str=True,
variable_standardization=False,
forest_begin=len(forest),
)
self._formulas.append(formula)
else:
raise TypeError(
"the element in formula_list should be either str or Formula, now is %s" % type(Formula)
)
forest += tree
else:
formula_sep_index.append(len(forest))
forest = link_variable(forest)
variable_connect_dict = link_variable(forest)
self._forest = forest
self._formulas = []
for i, sep in enumerate(formula_sep_index[:-1]):
self._formulas.append(Formula(forest[sep: formula_sep_index[i + 1]], is_str=False))
if variable_standardization:
self.variable_standardization(inplace=True, const_mathord=const_mathord)
self.variable_standardization(
inplace=True,
const_mathord=const_mathord,
variable_connect_dict=variable_connect_dict
)

def __iter__(self):
return iter(self._formulas)
Expand All @@ -102,14 +183,47 @@ def __getitem__(self, item) -> Formula:
def __contains__(self, item) -> bool:
return item in self._formulas

def variable_standardization(self, inplace=False, const_mathord=None):
def variable_standardization(self, inplace=False, const_mathord=None, variable_connect_dict=None):
ret = []
for formula in self._formulas:
ret.append(formula.variable_standardization(inplace=inplace, const_mathord=const_mathord))
ret.append(formula.variable_standardization(inplace=inplace, const_mathord=const_mathord,
variable_connect_dict=variable_connect_dict))
return ret

def to_str(self):
return pformat(self._formulas)
return pformat(self._forest)

def __repr__(self):
return "<FormulaGroup: %s>" % ";".join([repr(_formula) for _formula in self._formulas])

@property
def ast(self):
return self._forest

@property
def elements(self):
return [self.ast_graph.nodes[node] for node in self.ast_graph.nodes]

@property
def ast_graph(self) -> (nx.Graph, nx.DiGraph):
edges = [(edge[0], edge[1]) for edge in get_edges(self._forest) if edge[2] == 3]
tree = nx.DiGraph()
for node in self._forest:
tree.add_node(
node["val"]["id"],
**node["val"]
)
tree.add_edges_from(edges)
return tree


def link_formulas(*formula: Formula, **kwargs):
forest = []
for form in formula:
forest += form.reset_ast(
forest_begin=len(forest),
**kwargs
)
variable_connect_dict = link_variable(forest)
for form in formula:
form.variable_standardization(inplace=True, variable_connect_dict=variable_connect_dict, **kwargs)
2 changes: 1 addition & 1 deletion EduNLP/Formula/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .Formula import Formula, FormulaGroup
from .Formula import Formula, FormulaGroup, link_formulas
from .ast import link_variable
from .Formula import CONST_MATHORD
13 changes: 8 additions & 5 deletions EduNLP/Formula/ast/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
__all__ = ["str2ast", "get_edges", "ast", "link_variable"]


def str2ast(formula: str):
return ast(formula, is_str=True)
def str2ast(formula: str, *args, **kwargs):
return ast(formula, is_str=True, *args, **kwargs)


def ast(formula: (str, List[Dict]), index=0, forest_begin=0, father_tree=None, is_str=False):
Expand All @@ -18,7 +18,7 @@ def ast(formula: (str, List[Dict]), index=0, forest_begin=0, father_tree=None, i
Parameters
----------
formula: str or List[Dict]
公式字符串或通过katex解析得到的结构体
index: int
本子树在树上的位置
forest_begin: int
Expand Down Expand Up @@ -224,8 +224,11 @@ def link_variable(forest):
l_v = [] + v
index = l_v.pop(i)
forest[index]['structure']['forest'] = l_v

return forest
variable_connect_dict = {
"var2id": forest_connect_dict,
"var_code": {}
}
return variable_connect_dict


def get_edges(forest):
Expand Down
1 change: 1 addition & 0 deletions EduNLP/SIF/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
# 2021/5/16 @ tongshiwei

from .sif import is_sif, to_sif, sif4sci
from .tokenization import link_formulas
20 changes: 19 additions & 1 deletion EduNLP/SIF/sif.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# 2021/5/16 @ tongshiwei

from .segment import seg
from .tokenization import tokenize
from .tokenization import tokenize, link_formulas


def is_sif(item):
Expand Down Expand Up @@ -60,6 +60,24 @@ def sif4sci(item: str, figures: (dict, bool) = None, safe=True, symbol: str = No
>>> sif4sci(test_item, symbol="gm",
... tokenization_params={"formula_params": {"method": "ast", "return_type": "list"}})
['如图所示', '\\\\bigtriangleup', 'A', 'B', 'C', '面积', '[MARK]', '[FIGURE]']
>>> test_item_1 = {
... "stem": r"若$x=2$, $y=\\sqrt{x}$,则下列说法正确的是$\\SIFChoice$",
... "options": [r"$x < y$", r"$y = x$", r"$y < x$"]
... }
>>> tls = [
... sif4sci(e, symbol="gm",
... tokenization_params={
... "formula_params": {
... "method": "ast", "return_type": "list", "ord2token": True, "var_numbering": True,
... "link_variable": False}
... })
... for e in ([test_item_1["stem"]] + test_item_1["options"])
... ]
>>> tls[1:]
[['mathord_0', '<', 'mathord_1'], ['mathord_0', '=', 'mathord_1'], ['mathord_0', '<', 'mathord_1']]
>>> link_formulas(*tls)
>>> tls[1:]
[['mathord_0', '<', 'mathord_1'], ['mathord_1', '=', 'mathord_0'], ['mathord_1', '<', 'mathord_0']]
"""
if safe is True and is_sif(item) is not True:
item = to_sif(item)
Expand Down
2 changes: 1 addition & 1 deletion EduNLP/SIF/tokenization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# coding: utf-8
# 2021/5/18 @ tongshiwei

from .tokenization import tokenize
from .tokenization import tokenize, link_formulas
4 changes: 2 additions & 2 deletions EduNLP/SIF/tokenization/formula/ast_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ def ast_tokenize(formula, ord2token=False, var_numbering=False, return_type="for
<Formula: {x + y}^\\frac{\\pi}{2} + 1 = x>
"""
if return_type == "list":
ast = Formula(formula, variable_standardization=True).ast
ast = Formula(formula, variable_standardization=True).ast_graph
return traversal_formula(ast, ord2token=ord2token, var_numbering=var_numbering)
elif return_type == "formula":
return Formula(formula)
elif return_type == "ast":
return Formula(formula).ast
return Formula(formula).ast_graph
else:
raise ValueError()

Expand Down
Loading

0 comments on commit cc2fb0d

Please sign in to comment.