From 215a54c8f406512774bbdfc9a13db391303e9691 Mon Sep 17 00:00:00 2001 From: jstoobysmith <72603918+jstoobysmith@users.noreply.github.com> Date: Wed, 5 Mar 2025 07:20:13 +0000 Subject: [PATCH] refactor: Major refactor of Elab for tensors --- .../Tensors/TensorSpecies/Basic.lean | 4 + PhysLean/Relativity/Tensors/Tree/Elab.lean | 527 +++++++----------- scripts/MetaPrograms/notes.lean | 2 +- 3 files changed, 218 insertions(+), 315 deletions(-) diff --git a/PhysLean/Relativity/Tensors/TensorSpecies/Basic.lean b/PhysLean/Relativity/Tensors/TensorSpecies/Basic.lean index 9106114a..621c9a2d 100644 --- a/PhysLean/Relativity/Tensors/TensorSpecies/Basic.lean +++ b/PhysLean/Relativity/Tensors/TensorSpecies/Basic.lean @@ -338,6 +338,10 @@ def liftTensor {n : ℕ} {c : Fin n → S.C} {E : Type} [AddCommMonoid E] [Modul (S.F.obj (OverColor.mk c) →ₗ[S.k] E) := PiTensorProduct.lift +/-- The number of indices `n` from a tensor. -/ +@[nolint unusedArguments] +def numIndices {S : TensorSpecies} {n : ℕ} {c : Fin n → S.C} (_ : S.F.obj (OverColor.mk c)) : ℕ := n + end TensorSpecies end diff --git a/PhysLean/Relativity/Tensors/Tree/Elab.lean b/PhysLean/Relativity/Tensors/Tree/Elab.lean index 689773f0..84962453 100644 --- a/PhysLean/Relativity/Tensors/Tree/Elab.lean +++ b/PhysLean/Relativity/Tensors/Tree/Elab.lean @@ -47,21 +47,13 @@ import PhysLean.Relativity.Lorentz.ComplexTensor.Basic this information. -/ -open Lean -open Lean.Elab.Term - -open Lean -open Lean.Meta -open Lean.Elab -open Lean.Elab.Term -open Lean Meta Elab Tactic -open IndexNotation -open complexLorentzTensor +open Lean Meta Elab Tactic Term IndexNotation + namespace TensorTree /-! -## Indexies +## Indices -/ @@ -99,7 +91,7 @@ def indexToIdent (stx : Syntax) : TermElabM Ident := | `(indexExpr|$a:ident) => return a | `(indexExpr| τ($a:ident)) => return a | _ => - throwError "Unsupported tensor expression syntax in indexToIdent: {stx}" + throwError "Unsupported expression syntax in indexToIdent: {stx}" /-- Takes a pair ``a b : ℕ × TSyntax `indexExpr``. If `a.1 < b.1` and `a.2 = b.2` then outputs `some (a.1, b.1)`, otherwise `none`. -/ @@ -119,171 +111,10 @@ def indexToDual (stx : Syntax) : Bool := /-! -## Tensor expressions - --/ - -/-- A syntax category for tensor expressions. -/ -declare_syntax_cat tensorExpr - -/-- The syntax for a tensor node. -/ -syntax term "|" (ppSpace indexExpr)* : tensorExpr - -/-- Equality. -/ -syntax:40 tensorExpr "=" tensorExpr:41 : tensorExpr - -/-- The syntax for tensor prod two tensor nodes. -/ -syntax:70 tensorExpr "⊗" tensorExpr:71 : tensorExpr - -/-- The syntax for tensor addition. -/ -syntax tensorExpr "+" tensorExpr : tensorExpr - -/-- Allowing brackets to be used in a tensor expression. -/ -syntax "(" tensorExpr ")" : tensorExpr - -/-- Scalar multiplication for tensors. -/ -syntax term "•ₜ" tensorExpr : tensorExpr - -/-- group action for tensors. -/ -syntax term "•ₐ" tensorExpr : tensorExpr - -/-- Negation of a tensor tree. -/ -syntax "-" tensorExpr : tensorExpr - -namespace TensorNode - -/-! - -## For tensor nodes. - -The operations are done in the following order: -- evaluation. -- dualization. -- contraction. - -We also want to ensure the number of indices is correct. +## Manipulation of lists of indexExpr -/ -/-- The indices of a tensor node. Before contraction, and evaluation. -/ -partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do - match stx with - | `(tensorExpr| $_:term | $[$args]*) => do - let indices ← args.toList.mapM fun arg => do - match arg with - | `(indexExpr|$t:indexExpr) => pure t - return indices - | _ => - throwError "Unsupported tensor expression syntax in getIndicesNode: {stx}" - -/-- Uses the structure of the tensor to get the number of indices. -/ -def getNoIndicesExact (stx : Syntax) : TermElabM ℕ := do - let expr ← elabTerm stx none - let type ← inferType expr - let strType := toString type - let n := (String.splitOn strType "CategoryTheory.MonoidalCategoryStruct.tensorObj").length - match n with - | 1 => - match type with - | Expr.app _ (Expr.app _ (Expr.app _ (Expr.app _ c))) => - let typeC ← inferType c - match typeC with - | Expr.forallE _ (Expr.app _ a) _ _ => - let a' ← whnf a - match a' with - | Expr.lit (Literal.natVal n) => return n - |_ => throwError s!"Could not extract number of indices from tensor - {stx} (getNoIndicesExact). " - | _ => throwError s!"Could not extract number of indices from tensor - {stx} (getNoIndicesExact). " - | _ => return 1 - | k => return k - -/-- The construction of an expression corresponding to the type of a given string once parsed. -/ -def stringToType (str : String) : TermElabM (Option Expr) := do - let env ← getEnv - let stx := Parser.runParserCategory env `term str - match stx with - | Except.error _ => return none - | Except.ok stx => return (some (← elabTerm stx none)) - -/-- The construction of an expression corresponding to the type of a given string once parsed. -/ -def stringToTerm (str : String) : TermElabM Term := do - let env ← getEnv - let stx := Parser.runParserCategory env `term str - match stx with - | Except.error _ => throwError "Could not create type from string (stringToTerm). " - | Except.ok stx => - match stx with - | `(term| $e) => return e - -/-- Specific types of tensors which appear which we want to elaborate in specific ways. -/ -def specialTypes : List (String × (Term → Term)) := [ - ("CoeSort.coe Lorentz.complexCo", fun T => - Syntax.mkApp (mkIdent ``TensorTree.vecNodeE) #[mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.down, T]), - ("CoeSort.coe Lorentz.complexContr", fun T => - Syntax.mkApp (mkIdent ``TensorTree.vecNodeE) #[mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.up, T]), - ("ModuleCat.carrier (Lorentz.complexContr ⊗ Lorentz.complexCo).V", fun T => - Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.up, mkIdent ``complexLorentzTensor.Color.down, T]), - ("ModuleCat.carrier (Lorentz.complexContr ⊗ Lorentz.complexContr).V", fun T => - Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.up, mkIdent ``complexLorentzTensor.Color.up, T]), - ("ModuleCat.carrier (Lorentz.complexCo ⊗ Lorentz.complexCo).V", fun T => - Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.down, mkIdent ``complexLorentzTensor.Color.down, T]), - ("ModuleCat.carrier (Lorentz.complexCo ⊗ Lorentz.complexContr).V", fun T => - Syntax.mkApp (mkIdent ``TensorTree.twoNodeE) #[ - mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.down, - mkIdent ``complexLorentzTensor.Color.up, T]), - ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexCo ⊗ Lorentz.complexCo", fun T => - Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[ - mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.down, - mkIdent ``complexLorentzTensor.Color.down, T]), - ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexContr ⊗ Lorentz.complexContr", fun T => - Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[ - mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.up, - mkIdent ``complexLorentzTensor.Color.up, T]), - ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Lorentz.complexContr ⊗ Fermion.leftHanded ⊗ Fermion.rightHanded", fun T => - Syntax.mkApp (mkIdent ``TensorTree.constThreeNodeE) #[ - mkIdent ``complexLorentzTensor, mkIdent ``complexLorentzTensor.Color.up, - mkIdent ``complexLorentzTensor.Color.upL, - mkIdent ``complexLorentzTensor.Color.upR, T]), - ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.leftHanded ⊗ Fermion.leftHanded", fun T => - Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[ - mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.upL, - mkIdent ``complexLorentzTensor.Color.upL, T]), - ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.altLeftHanded ⊗ Fermion.altLeftHanded", fun T => - Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[ - mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.downL, - mkIdent ``complexLorentzTensor.Color.downL, T]), - ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.altRightHanded ⊗ Fermion.altRightHanded", fun T => - Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[ - mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.downR, - mkIdent ``complexLorentzTensor.Color.downR, T]), - ("𝟙_ (Rep ℂ SL(2, ℂ)) ⟶ Fermion.rightHanded ⊗ Fermion.rightHanded", fun T => - Syntax.mkApp (mkIdent ``TensorTree.constTwoNodeE) #[ - mkIdent ``complexLorentzTensor, - mkIdent ``complexLorentzTensor.Color.upR, - mkIdent ``complexLorentzTensor.Color.upR, T])] - -/-- The syntax associated with a terminal node of a tensor tree. -/ -def termNodeSyntax (T : Term) : TermElabM Term := do - let expr ← elabTerm T none - let type ← inferType expr - match type with - | Expr.app _ (Expr.app _ (Expr.app _ _)) => - return Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T] - | _ => return Syntax.mkApp (mkIdent ``TensorTree.vecNode) #[T] - /-- Adjusts a list `List ℕ` by subtracting from each natural number the number of elements before it in the list which are less than itself. This is used to form a list of pairs which can be used for evaluating indices. -/ @@ -294,23 +125,25 @@ def evalAdjustPos (l : List ℕ) : List ℕ := (x :: prev, x - e)) l.reverse [] l'.2.reverse -/-- The positions in getIndicesNode which get evaluated, and the value they take. -/ -partial def getEvalPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do - let ind ← getIndices stx +/-- For list of `indexExpr` e.g. `[α, 3, β, 2, γ]`, `getEvalPos` + returns a list of pairs `ℕ × ℕ` related to indices which are numbers. + The second element of each pair is the number corresponding to that index. + The first element is the position of that number in the list of indices when + all other numbered indices before it are removed. Thus for the example given + `getEvalPos` outputs `[(1, 3), (2, 2)]`. -/ +def getEvalPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × ℕ)) := do let indEnum := ind.zipIdx let evals := indEnum.filter (fun x => indexExprIsNum x.1) let evals2 ← (evals.mapM (fun x => indexToNum x.1)) let pos := evalAdjustPos (evals.map (fun x => x.2)) return List.zip pos evals2 -/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.eval` to the given term. -/ -def evalSyntax (l : List (ℕ × ℕ)) (T : Term) : Term := - l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``TensorTree.eval) - #[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T - -/-- The pairs of positions in getIndicesNode which get contracted. -/ -partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do - let ind ← getIndices stx +/-- For list of `indexExpr` e.g. `[α, 3, β, α, 2, γ]`, `getContrPos` + first removes all indices which are numbers (e.g. `[α, β, α, γ]`). + It then outputs pairs `(a, b)` in `ℕ × ℕ` of positions of this list with `a < b` + such that the index at `a` is equal to the index at `b`. It checkes whether or not + an element is contracted more then once. -/ +def getContrPos (ind : List (TSyntax `indexExpr)) : TermElabM (List (ℕ × ℕ)) := do let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x) let indEnum := indFilt.zipIdx let bind := List.flatMap (fun a => indEnum.map (fun b => (a, b))) indEnum @@ -320,12 +153,10 @@ partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do return filt /-- The list of indices after contraction or evaluation. -/ -def withoutContr (ind : List (TSyntax `indexExpr)) : TermElabM (List (TSyntax `indexExpr)) := do +def withoutContrEval (ind : List (TSyntax `indexExpr)) : TermElabM (List (TSyntax `indexExpr)) := do let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x) return indFilt.filter (fun x => indFilt.count x ≤ 1) -end TensorNode - /-- Takes a list and puts consecutive elements into pairs. e.g. [0, 1, 2, 3] becomes [(0, 1), (2, 3)]. -/ def toPairs (l : List ℕ) : List (ℕ × ℕ) := @@ -345,64 +176,14 @@ def contrListAdjust (l : List (ℕ × ℕ)) : List (ℕ × ℕ) := (x :: prev, x - e)) l'.reverse [] toPairs l''.2.reverse -/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.contr` to the given term. -/ -def contrSyntax (l : List (ℕ × ℕ)) (T : Term) : Term := - (contrListAdjust l).foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr) - #[Syntax.mkNumLit (toString x0), - Syntax.mkNumLit (toString x1), mkIdent ``rfl, T']) T - -namespace ProdNode - /-! -## For product nodes. - -For a product node we can take the tensor product, and then contract the indices. +## Permutations of indices -/ -/-- Gets the indices associated with a product node. -/ -partial def getIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do - match stx with - | `(tensorExpr| $_:term | $[$args]*) => do - return (← TensorNode.withoutContr (← TensorNode.getIndices stx)) - | `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do - let indicesA ← TensorNode.withoutContr (← getIndices a) - let indicesB ← TensorNode.withoutContr (← getIndices b) - return indicesA ++ indicesB - | `(tensorExpr| ($a:tensorExpr)) => do - return (← getIndices a) - | _ => - throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}" - -/-- The pairs of positions in getIndicesNode which get contracted. -/ -partial def getContrPos (stx : Syntax) : TermElabM (List (ℕ × ℕ)) := do - let ind ← getIndices stx - let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x) - let indEnum := indFilt.zipIdx - let bind := List.flatMap (fun a => indEnum.map (fun b => (a, b))) indEnum - let filt ← bind.filterMapM (fun x => indexPosEq x.1 x.2) - if ¬ ((filt.map Prod.fst).Nodup ∧ (filt.map Prod.snd).Nodup) then - throwError "To many contractions" - return filt - -/-- The list of indices after contraction. -/ -def withoutContr (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do - let ind ← getIndices stx - let indFilt : List (TSyntax `indexExpr) := ind.filter (fun x => ¬ indexExprIsNum x) - return ind.filter (fun x => indFilt.count x ≤ 1) - -/-- The syntax associated with a product of tensors. -/ -def prodSyntax (T1 T2 : Term) : Term := - Syntax.mkApp (mkIdent ``TensorTree.prod) #[T1, T2] - -end ProdNode - -/-! - -## Permutation constructions +open PhysLean.Fin --/ /-- Given two lists of indices returns the `List (ℕ)` representing the how one list permutes into the other. -/ def getPermutation (l1 l2 : List (TSyntax `indexExpr)) : TermElabM (List (ℕ)) := do @@ -413,34 +194,125 @@ def getPermutation (l1 l2 : List (TSyntax `indexExpr)) : TermElabM (List (ℕ)) (fun x => l1enum.find? (fun y => Lean.TSyntax.getId y.1 = Lean.TSyntax.getId x)) return l2''.map fun x => x.2 -open PhysLean.Fin +/-- The construction of an expression corresponding to the type of a given string once parsed. -/ +def stringToTerm (str : String) : TermElabM Term := do + let env ← getEnv + let stx := Parser.runParserCategory env `term str + match stx with + | Except.error _ => throwError "Could not create type from string (stringToTerm). " + | Except.ok stx => + match stx with + | `(term| $e) => return e /-- Given two lists of indices returns the permutation between them based on `finMapToEquiv`. -/ -def getPermutationSyntax (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term := do +def getPermutationTerm (l1 l2 : List (TSyntax `indexExpr)) : TermElabM Term := do let lPerm ← getPermutation l1 l2 let l2Perm ← getPermutation l2 l1 let permString := "![" ++ String.intercalate ", " (lPerm.map toString) ++ "]" let perm2String := "![" ++ String.intercalate ", " (l2Perm.map toString) ++ "]" - let P1 ← TensorNode.stringToTerm permString - let P2 ← TensorNode.stringToTerm perm2String + let P1 ← stringToTerm permString + let P2 ← stringToTerm perm2String let stx := Syntax.mkApp (mkIdent ``finMapToEquiv) #[P1, P2] return stx -namespace negNode +/-! -/-- The syntax associated with a product of tensors. -/ -def negSyntax (T1 : Term) : Term := - Syntax.mkApp (mkIdent ``TensorTree.neg) #[T1] +## Syntax for tensor expressions. + +-/ + +/-- A syntax category for tensor expressions. -/ +declare_syntax_cat tensorExpr + +/-- The syntax for a tensor node. -/ +syntax term "|" (ppSpace indexExpr)* : tensorExpr + +/-- Equality. -/ +syntax:40 tensorExpr "=" tensorExpr:41 : tensorExpr + +/-- The syntax for tensor prod two tensor nodes. -/ +syntax:70 tensorExpr "⊗" tensorExpr:71 : tensorExpr + +/-- The syntax for tensor addition. -/ +syntax tensorExpr "+" tensorExpr : tensorExpr + +/-- Allowing brackets to be used in a tensor expression. -/ +syntax "(" tensorExpr ")" : tensorExpr + +/-- Scalar multiplication for tensors. -/ +syntax term "•ₜ" tensorExpr : tensorExpr + +/-- group action for tensors. -/ +syntax term "•ₐ" tensorExpr : tensorExpr + +/-- Negation of a tensor tree. -/ +syntax "-" tensorExpr : tensorExpr + +/-! + +## Syntax of tensor expressions to indices. + +-/ -end negNode +/-- For syntax of the form `T` where `T` is `S.F.obj (OverColor.mk c)` this returns + the value of `TensorSpecies.numIndices T`. That is, the exact number of indices + associated with that tensor. -/ +def getNumIndicesExact (stx : Syntax) : TermElabM ℕ := do + match stx with + | `($t:term) => + let a ← elabTerm (← `(TensorSpecies.numIndices $t)) (some (mkConst ``Nat)) + let a' ← whnf a + match a' with + | Expr.lit (Literal.natVal n) => + return n + |_ => throwError s!"Could not extract number of indices from tensor + {stx} (getNoIndicesExact). " + +/-- For syntax of the form `T | α β 2 β`, `getAllIndices` returns a list `[α, β, 2, β]` + of all `indexExpr`. -/ +def getAllIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do + match stx with + | `(tensorExpr| $_:term | $[$args]*) => do + let indices ← args.toList.mapM fun arg => do + match arg with + | `(indexExpr|$t:indexExpr) => pure t + return indices + | _ => + throwError "Unsupported tensor expression syntax in getIndicesNode: {stx}" -/-- Returns the full list of indices after contraction. TODO: Include evaluation. -/ +/-- The function `getProdIndices` is defined for the following syntax: +1. For e.g. `T | α β 2 β`, it returns all uncontracted and unevaluated indices e.g.`[α]` +2. For e.g. `T1 | α β 2 β ⊗ T2 | α γ δ δ` it returns all unevaluated indices which + are not contracted in either tensor e.g. `[α, α, γ]`. +3. For e.g. `(T1 | α β 2 β ⊗ T2 | α γ δ δ) ⊗ T3 | γ` it does `2` recursively e.g. `[γ, γ]` +-/ +partial def getProdIndices (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do + match stx with + | `(tensorExpr| $_:term | $[$args]*) => do + return (← withoutContrEval (← getAllIndices stx)) + | `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do + let indicesA ← withoutContrEval (← getProdIndices a) + let indicesB ← withoutContrEval (← getProdIndices b) + return indicesA ++ indicesB + | `(tensorExpr| ($a:tensorExpr)) => do + return (← getProdIndices a) + | _ => + throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}" + +/-- Returns the remaining indices of a tensor expression after contraction and evaulation. + Thus every index in the output of `getIndicesFull` is ident and there are no duplicates. + Examples are: +1. `T | α β 2 β` gives `[α]` +2. `T1 | α β 2 β ⊗ T2 | α γ δ δ` gives `[γ]` +3. `(T1 | α β 2 β ⊗ T2 | α γ δ δ) ⊗ T3 | γ` gives `[]` +4. `T1 | α β 2 β + T2 | α 4 δ δ` gives `[α]` +-/ partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do match stx with | `(tensorExpr| $_:term | $[$args]*) => do - return (← TensorNode.withoutContr (← TensorNode.getIndices stx)) + return (← withoutContrEval (← getAllIndices stx)) | `(tensorExpr| $_:tensorExpr ⊗ $_:tensorExpr) => do - return (← ProdNode.withoutContr stx) + return (← withoutContrEval (← getProdIndices stx)) | `(tensorExpr| ($a:tensorExpr)) => do return (← getIndicesFull a) | `(tensorExpr| -$a:tensorExpr) => do @@ -450,60 +322,26 @@ partial def getIndicesFull (stx : Syntax) : TermElabM (List (TSyntax `indexExpr) | `(tensorExpr| $a:tensorExpr + $_:tensorExpr) => do return (← getIndicesFull a) | _ => - throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}" - -namespace SMul - -/-- The syntax associated with the scalar multiplication of tensors. -/ -def smulSyntax (c T : Term) : Term := - Syntax.mkApp (mkIdent ``TensorTree.smul) #[c, T] - -end SMul - -namespace Action - -/-- The syntax associated with the group action of tensors. -/ -def actionSyntax (c T : Term) : Term := - Syntax.mkApp (mkIdent ``TensorTree.action) #[c, T] - -end Action - -namespace Add + throwError "Unsupported tensor expression syntax in getIndicesFull: {stx}" /-- Gets the indices associated with the LHS of an addition. -/ -partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do +def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do match stx with | `(tensorExpr| $a:tensorExpr + $_:tensorExpr) => do return (← getIndicesFull a) | _ => - throwError "Unsupported tensor expression syntax in Add.getIndicesLeft: {stx}" + throwError "Unsupported tensor expression syntax in getIndicesLeft: {stx}" /-- Gets the indices associated with the RHS of an addition. -/ -partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do +def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do match stx with | `(tensorExpr| $_:tensorExpr + $a:tensorExpr) => do return (← getIndicesFull a) | _ => - throwError "Unsupported tensor expression syntax in Add.getIndicesRight: {stx}" - -/-- The syntax for a equality of tensor trees. -/ -def addSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do - let P := Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax] - let RHS := Syntax.mkApp (mkIdent ``TensorTree.perm) #[P, T2] - return Syntax.mkApp (mkIdent ``add) #[T1, RHS] - -end Add - -namespace Equality - -/-! - -## For equality. - --/ + throwError "Unsupported tensor expression syntax in getIndicesRight: {stx}" /-- Gets the indices associated with the LHS of an equality. -/ -partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do +def getIndicesLeftEq (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do match stx with | `(tensorExpr| $a:tensorExpr = $_:tensorExpr) => do return (← getIndicesFull a) @@ -511,62 +349,123 @@ partial def getIndicesLeft (stx : Syntax) : TermElabM (List (TSyntax `indexExpr) throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}" /-- Gets the indices associated with the RHS of an equality. -/ -partial def getIndicesRight (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do +def getIndicesRightEq (stx : Syntax) : TermElabM (List (TSyntax `indexExpr)) := do match stx with | `(tensorExpr| $_:tensorExpr = $a:tensorExpr) => do return (← getIndicesFull a) | _ => throwError "Unsupported tensor expression syntax in getIndicesProd: {stx}" +/-! + +## Modifying terms to tensor trees + +-/ + +/-- For a term of the form `T` where `T` is `S.F.obj (OverColor.mk c)`, + `tensorTermToTensorTree` returns the term corresponding to the `tensorNode T` -/ +def nodeTermMap (T : Term) : Term := Syntax.mkApp (mkIdent ``TensorTree.tensorNode) #[T] + +/-- Given a list `l` of pairs `ℕ × ℕ` and a term `T` corresponding to a tensor tree, + for each `(a, b)` in `l`, `evalSyntax` applies `TensorTree.eval a b` to `T` recursively. + Here `a` is the position of the index to be evaluated and `b` is the value it is evaluated to. + + For example, if `l` is `[(1, 2), (1, 4)]` and `T` is a tensor tree then `evalSyntax l T` + is `TensorTree.eval 1 4 (TensorTree.eval 1 2 T)`. + + The list `l` is expected to be the output of `getEvalPos`. +-/ +def evalTermMap (l : List (ℕ × ℕ)) (T : Term) : Term := + l.foldl (fun T' (x1, x2) => Syntax.mkApp (mkIdent ``TensorTree.eval) + #[Syntax.mkNumLit (toString x1), Syntax.mkNumLit (toString x2), T']) T + +/-- For each element of `l : List (ℕ × ℕ)` applies `TensorTree.contr` to the given term. -/ +def contrTermMap (l : List (ℕ × ℕ)) (T : Term) : Term := + (contrListAdjust l).foldl (fun T' (x0, x1) => Syntax.mkApp (mkIdent ``TensorTree.contr) + #[Syntax.mkNumLit (toString x0), + Syntax.mkNumLit (toString x1), mkIdent ``rfl, T']) T + +/-- The syntax associated with a product of tensors. -/ +def prodTermMap (T1 T2 : Term) : Term := + Syntax.mkApp (mkIdent ``TensorTree.prod) #[T1, T2] + +/-- The syntax associated with a product of tensors. -/ +def negTermMap (T1 : Term) : Term := + Syntax.mkApp (mkIdent ``TensorTree.neg) #[T1] + +/-- The syntax associated with the scalar multiplication of tensors. -/ +def smulTermMap (c T : Term) : Term := + Syntax.mkApp (mkIdent ``TensorTree.smul) #[c, T] + +/-- The syntax associated with the group action of tensors. -/ +def actionTermMap (c T : Term) : Term := + Syntax.mkApp (mkIdent ``TensorTree.action) #[c, T] + /-- The syntax for a equality of tensor trees. -/ -def equalSyntax (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do +def addTermMap (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do + let P := Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax] + let RHS := Syntax.mkApp (mkIdent ``TensorTree.perm) #[P, T2] + return Syntax.mkApp (mkIdent ``add) #[T1, RHS] + +/-- The syntax for a equality of tensor trees. -/ +def equalTermMap (permSyntax : Term) (T1 T2 : Term) : TermElabM Term := do let X1 := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[T1] let P := Syntax.mkApp (mkIdent ``OverColor.equivToHomEq) #[permSyntax] let X2' := Syntax.mkApp (mkIdent ``TensorTree.perm) #[P, T2] let X2 := Syntax.mkApp (mkIdent ``TensorTree.tensor) #[X2'] return Syntax.mkApp (mkIdent ``Eq) #[X1, X2] -end Equality +/-! + +## Syntax to tensor tree -/-- Creates the syntax associated with a tensor node. -/ +-/ + +/-- The function `syntaxFull` -/ partial def syntaxFull (stx : Syntax) : TermElabM Term := do match stx with | `(tensorExpr| $T:term | $[$args]*) => - let indices ← TensorNode.getIndices stx - let rawIndex ← TensorNode.getNoIndicesExact T + let indices ← getAllIndices stx + let rawIndex ← getNumIndicesExact T if indices.length ≠ rawIndex then throwError "The expected number of indices {rawIndex} does not match the tensor {T}." - let tensorNodeSyntax ← TensorNode.termNodeSyntax T - let evalSyntax := TensorNode.evalSyntax (← TensorNode.getEvalPos stx) tensorNodeSyntax - let contrSyntax := contrSyntax (← TensorNode.getContrPos stx) evalSyntax + let tensorNodeSyntax := nodeTermMap T + let evalSyntax := evalTermMap (← getEvalPos indices) tensorNodeSyntax + let contrSyntax := contrTermMap (← getContrPos indices) evalSyntax return contrSyntax | `(tensorExpr| $a:tensorExpr ⊗ $b:tensorExpr) => do - let prodSyntax := ProdNode.prodSyntax (← syntaxFull a) (← syntaxFull b) - let contrSyntax := contrSyntax (← ProdNode.getContrPos stx) prodSyntax + let prodSyntax := prodTermMap (← syntaxFull a) (← syntaxFull b) + let contrSyntax := contrTermMap (← getContrPos (← getProdIndices stx)) prodSyntax return contrSyntax | `(tensorExpr| ($a:tensorExpr)) => do return (← syntaxFull a) | `(tensorExpr| -$a:tensorExpr) => do - return negNode.negSyntax (← syntaxFull a) + return negTermMap (← syntaxFull a) | `(tensorExpr| $c:term •ₜ $a:tensorExpr) => do - return SMul.smulSyntax c (← syntaxFull a) + return smulTermMap c (← syntaxFull a) | `(tensorExpr| $c:term •ₐ $a:tensorExpr) => do - return Action.actionSyntax c (← syntaxFull a) + return actionTermMap c (← syntaxFull a) | `(tensorExpr| $a + $b) => do - let indicesLeft ← Add.getIndicesLeft stx - let indicesRight ← Add.getIndicesRight stx - let permSyntax ← getPermutationSyntax indicesLeft indicesRight - let addSyntax ← Add.addSyntax permSyntax (← syntaxFull a) (← syntaxFull b) + let indicesLeft ← getIndicesLeft stx + let indicesRight ← getIndicesRight stx + let permSyntax ← getPermutationTerm indicesLeft indicesRight + let addSyntax ← addTermMap permSyntax (← syntaxFull a) (← syntaxFull b) return addSyntax | `(tensorExpr| $a:tensorExpr = $b:tensorExpr) => do - let indicesLeft ← Equality.getIndicesLeft stx - let indicesRight ← Equality.getIndicesRight stx - let permSyntax ← getPermutationSyntax indicesLeft indicesRight - let equalSyntax ← Equality.equalSyntax permSyntax (← syntaxFull a) (← syntaxFull b) + let indicesLeft ← getIndicesLeftEq stx + let indicesRight ← getIndicesRightEq stx + let permSyntax ← getPermutationTerm indicesLeft indicesRight + let equalSyntax ← equalTermMap permSyntax (← syntaxFull a) (← syntaxFull b) return equalSyntax | _ => throwError "Unsupported tensor expression syntax in elaborateTensorNode: {stx}" +/-! + +## Elaboration + +-/ + /-- An elaborator for tensor nodes. This is to be generalized. -/ def elaborateTensorNode (stx : Syntax) : TermElabM Expr := do let tensorExpr ← elabTerm (← syntaxFull stx) none diff --git a/scripts/MetaPrograms/notes.lean b/scripts/MetaPrograms/notes.lean index 75606a4d..77a342f5 100644 --- a/scripts/MetaPrograms/notes.lean +++ b/scripts/MetaPrograms/notes.lean @@ -344,7 +344,7 @@ def harmonicOscillator : Note where ] def higgsPotential : Note where - title := "The Higgs potential 🚧" + title := "The Higgs potential" curators := ["Joseph Tooby-Smith"] parts := [ .h1 "Introduction",