diff --git a/README.md b/README.md index 03d74be..c7e4d9e 100644 --- a/README.md +++ b/README.md @@ -155,13 +155,13 @@ Markdown(doc/LeiEq.md) | 表驱动编程(Table-Driven Programming) [简单 Python 基础] -> 简单 | -Markdown(doc/Continuation.md) | +简单 | +[Markdown](doc/Continuation.md) | 续延(Continuation) [简单 Python 基础] -> 中等 | -Markdown(doc/Algeff.md) | +中等 | +[Markdown](doc/Algeff.md) | 代数作用(Algebraic Effect) [简单 Python 基础,续延] diff --git a/doc/Algeff.md b/doc/Algeff.md new file mode 100644 index 0000000..f05aaa1 --- /dev/null +++ b/doc/Algeff.md @@ -0,0 +1,104 @@ +# 十分钟魔法练习:代数作用 + +### By 「玩火」改写 「penguin」 + +> 前置技能:Python基础,续延 + +## 可恢复异常 + +有时候我们希望在异常抛出后经过保存异常信息再跳回原来的地方继续执行。 + +显然Python默认异常处理无法直接实现这样的需求,因为在异常抛出时整个调用栈的信息全部被清除了。 + +但如果我们有了异常抛出时的续延那么可以同时抛出,在 `catch` 块中调用这个续延就能恢复之前的执行状态。 + +下面是实现可恢复异常的 `try-catch` : + +```python +from typing import Callable, List + +CE: List[Callable[[Exception, Callable], None]] = list() + + +def try_fun(body: Callable, handler: Callable[[Exception, Callable], None], cont: Callable): + CE.append(handler) + body(cont) + CE.pop() + + +def throw_fun(e: Exception, cont: Callable): + CE[-1](e, cont) +``` + +然后就可以像下面这样使用: + +```python +def try_run(t: int): + try_fun( + lambda cont: cont() if t != 0 else throw_fun(ArithmeticError("t==0"), cont), + lambda e, c: print(f"catch {e} and resumed") or c(), + lambda: print('final') + ) +``` + +而调用 `test(0)` 就会得到: + +``` +catch t==0 and resumed +final +``` + +## 代数作用 + +如果说在刚刚异常恢复的基础上希望在恢复时修补之前的异常错误就需要把之前的 `resume` 函数加上参数,这样修改以后它就成了代数作用(Algebaic Effect)的基础工具: + +```python +from typing import Callable, List + +CE: List[Callable[[Exception, Callable], None]] = list() + + +def try_fun(body: Callable, handler: Callable[[Exception, Callable[[int], None]], None], cont: Callable): + CE.append(handler) + body(cont) + CE.pop() + + +def throw_fun(e: Exception, cont: Callable): + CE[-1](e, cont) +``` + +使用方式如下: + +```python +def try_run(t: int): + try_fun( + lambda cont: cont() if t != 0 else throw_fun(ArithmeticError("t==0"), cont), + lambda e, c: print(f"catch {e} and resumed") or c(1), + lambda: print('final') + ) +``` + +而这个东西能实现不只是异常的功能,从某种程度上来说它能跨越函数发生作用(Perform Effect)。 + +比如说现在有个函数要记录日志,但是它并不关心如何记录日志,输出到标准流还是写入到文件或是上传到数据库。这时候它就可以调用 + +```python +perform(LogIt(INFO, "test"), ...) +``` + +来发生(Perform)一个记录日志的作用(Effect)然后再回到之前调用的位置继续执行,而具体这个作用产生了什么效果就由调用这个函数的人实现的 `try` 中的 `handler` 决定。这样发生作用和执行作用(Handle Effect)就解耦了。 + +进一步讲,发生作用和执行作用是可组合的。对于需要发生记录日志的作用,可以预先写一个输出到标准流的的执行器(Handler)一个输出到文件的执行器然后在调用函数的时候按需组合。这也就是它是代数的(Algebiac)的原因。 + +细心的读者还会发现这个东西还能跨函数传递数据,在需要某个量的时候调用 + +```python +perform(Ask("config"), ...) +``` + +就可以获得这个量而不用关心这个量是怎么来的,内存中来还是读取文件或者 HTTP 拉取。从而实现获取和使用的解耦。 + +而且这样的操作和状态单子非常非常像,实际上它就是和相比状态单子来说没有修改操作的读取器单子(Reader Monad)同构。 + +也就是说把执行器函数作为读取器单子的状态并在发生作用的时候执行对应函数就可以达到和用续延实现的代数作用相同的效果,反过来也同样可以模拟。 diff --git a/doc/Continuation.md b/doc/Continuation.md new file mode 100644 index 0000000..da3fcba --- /dev/null +++ b/doc/Continuation.md @@ -0,0 +1,132 @@ +# 十分钟魔法练习:续延 + +### By 「玩火」改写 「penguin」 + +> 前置技能:简单Python基础 + +## 续延 + +续延(Continuation)是指代表一个程序未来的函数,其参数是一个程序过去计算的结果。 + +比如对于这个程序: + +```python +def test(): + i = 1 # 1 + i += 1 # 2 + print(i) # 3 +``` + +它第二行以及之后的续延就是: + +```python +def cont(i: int): + i += 1 # 2 + print(i) # 3 +``` + +而第三行之后的续延是: + +```python +def cont(i: int): + print(i) # 3 +``` + +实际上可以把这整个程序的每一行改成一个续延然后用函数调用串起来变成和刚才的程序一样的东西: + +```python +def cont1(): + i = 1 + cont2(i) + +def cont2(i: int): + i += 1 + cont3(i) + +def cont3(i: int): + print(i) + +def test(): + cont1() +``` + +## 续延传递风格 + +续延传递风格(Continuation-Passing Style, CPS)是指把程序的续延作为函数的参数来获取函数返回值的编程思路。 + +听上去很难理解,把上面的三个 `cont` 函数改成CPS就很好理解了: + +```python +from typing import Callable + +def logic1(f: Callable[[int], None]): + i = 1 + f(i) + +def logic2(i: int, f: Callable[[int], None]): + i += 1 + f(i) + +def login3(i: int, f: Callable[[int], None]): + print(i) + f(i) + + +def test(): + logic1(lambda i2: logic2(i2, lambda i3: login3(i3, lambda _: None))) +``` + +每个 `logic` 函数的最后一个参数 `f` 就是整个程序的续延,而在每个函数的逻辑结束后整个程序的续延也就是未来会被调用。而 `test` 函数把整个程序组装起来。 + +小朋友,你有没有觉得最后的 `test` 函数写法超眼熟呢?实际上这个写法就是 Monad 的写法, Monad 的写法就是 CPS 。 + +另一个角度来说,这也是回调函数的写法,每个 `logic` 函数完成逻辑后调用了回调函数 `f` 来完成剩下的逻辑。实际上,异步回调思想很大程度上就是 CPS 。 + +## 有界续延 + +考虑有另一个函数 `callT` 调用了 `test` 函数,如: + +```python +def call_t(): + test() + print(3) +``` + +那么对于 `logic` 函数来说调用的 `f` 这个续延并不包括 `callT` 中的打印语句,那么实际上 `f` 这个续延并不是整个函数的未来而是 `test` 这个函数局部的未来。 + +这样代表局部程序的未来的函数就叫有界续延(Delimited Continuation)。 + +实际上在大多时候用的比较多的还是有界续延,因为在 Python 中获取整个程序的续延还是比较困难的,这需要全用 CPS 的写法。 + +## 异常 + +拿到了有界续延我们就能实现一大堆控制流魔法,这里拿异常处理举个例子,通过CPS写法自己实现一个 `try-throw` 。 + +首先最基本的想法是把每次调用 `try` 的 `catch` 函数保存起来,由于 `try` 可层层嵌套所以每次压入栈中,然后 `throw` 的时候将最近的 `catch` 函数取出来调用即可: + +```python +from typing import Callable, List + +CE: List[Callable[[Exception], None]] = list() + +def try_fun(body: Callable, handler: Callable[[Exception], None], cont: Callable): + CE.append(handler) + body(cont) + CE.pop() + +def throw_fun(e: Exception): + CE[-1](e) +``` + +这里 `body` 、 `Try` 、 `handler` 的最后一个参数都是这个程序的有界续延。 + +有了 `try-throw` 就可以按照CPS风格调用它们来达到处理异常的目的: + +```python +def try_run(t: int): + try_fun( + lambda cont: cont() if t != 0 else throw_fun(ArithmeticError("t==0")), + lambda e: print(f"catch {e}"), + lambda: print('final') + ) +``` diff --git a/magicpy/Algeff.py b/magicpy/Algeff.py new file mode 100644 index 0000000..4967073 --- /dev/null +++ b/magicpy/Algeff.py @@ -0,0 +1,21 @@ +from typing import Callable, List + +CE: List[Callable[[Exception, Callable], None]] = list() + + +def try_fun(body: Callable, handler: Callable[[Exception, Callable], None], cont: Callable): + CE.append(handler) + body(cont) + CE.pop() + + +def throw_fun(e: Exception, cont: Callable): + CE[-1](e, cont) + + +def try_run(t: int): + try_fun( + lambda cont: cont() if t != 0 else throw_fun(ArithmeticError("t==0"), cont), + lambda e, c: print(f"catch {e} and resumed") or c(), + lambda: print('final') + ) diff --git a/magicpy/Continuation.py b/magicpy/Continuation.py new file mode 100644 index 0000000..b8f406e --- /dev/null +++ b/magicpy/Continuation.py @@ -0,0 +1,51 @@ +from typing import Callable, List + + +def cont1(): + i = 1 + cont2(i) + + +def cont2(i: int): + i += 1 + cont3(i) + + +def cont3(i: int): + print(i) + + +def logic1(f: Callable[[int], None]): + i = 1 + f(i) + + +def logic2(i: int, f: Callable[[int], None]): + i += 1 + f(i) + + +def login3(i: int, f: Callable[[int], None]): + print(i) + f(i) + + +CE: List[Callable[[Exception], None]] = list() + + +def try_fun(body: Callable, handler: Callable[[Exception], None], cont: Callable): + CE.append(handler) + body(cont) + CE.pop() + + +def throw_fun(e: Exception): + CE[-1](e) + + +def try_run(t: int): + try_fun( + lambda cont: cont() if t != 0 else throw_fun(ArithmeticError("t==0")), + lambda e: print(f"catch {e}"), + lambda: print('final') + ) diff --git a/tests/test_magic.py b/tests/test_magic.py index b12c8ca..adaab12 100644 --- a/tests/test_magic.py +++ b/tests/test_magic.py @@ -99,6 +99,7 @@ def test_lambda_calculus(): ) assert str(expr) == "((λ x. (x (λ x. x))) y)" + def test_system_f(): from magicpy.STLC import NilEnv from magicpy.SystemF import Forall, Fun, App, Val, TVal, TArr, TForall, AppT @@ -122,3 +123,32 @@ def test_system_f(): )).gen_uuid() assert str(T.check_type(NilEnv())) == "(∀ a. (a -> (a -> a)))" assert str(IF.check_type(NilEnv())) == "(∀ a. ((∀ x. (x -> (x -> x))) -> (a -> (a -> a))))" + + +class TestContinuationAndAlgeff: + + def test_cont(self, capsys: CaptureFixture): + from magicpy.Continuation import cont1 + cont1() + out, err = capsys.readouterr() + assert "2\n" == out + + def test_login(self, capsys: CaptureFixture): + from magicpy.Continuation import logic1, logic2, login3 + logic1(lambda i2: logic2(i2, lambda i3: login3(i3, lambda _: None))) + out, err = capsys.readouterr() + assert "2\n" == out + + def test_try_run(self, capsys: CaptureFixture): + from magicpy.Continuation import try_run + try_run(1) + assert "final\n" == capsys.readouterr()[0] + try_run(0) + assert "catch t==0\n" == capsys.readouterr()[0] + + def test_resume(self, capsys: CaptureFixture): + from magicpy.Algeff import try_run + try_run(0) + out_lines = capsys.readouterr()[0].splitlines() + assert "catch t==0 and resumed" == out_lines[0] + assert "final" == out_lines[1]