From 3b32a2951b91e8942f7b848b30228adea9d16d66 Mon Sep 17 00:00:00 2001 From: Simon Marlow Date: Fri, 19 Jul 2024 08:13:42 -0700 Subject: [PATCH] Split up Glean.Query.Typecheck Summary: Just moving stuff around to better organise things Reviewed By: malanka Differential Revision: D59812606 fbshipit-source-id: 40b29fceb77284410c3ca4fc710f9fe84df4fa01 --- glean.cabal.in | 2 + glean/db/Glean/Query/Typecheck.hs | 406 +----------------------- glean/db/Glean/Query/Typecheck/Monad.hs | 183 +++++++++++ glean/db/Glean/Query/Typecheck/Unify.hs | 304 ++++++++++++++++++ 4 files changed, 491 insertions(+), 404 deletions(-) create mode 100644 glean/db/Glean/Query/Typecheck/Monad.hs create mode 100644 glean/db/Glean/Query/Typecheck/Unify.hs diff --git a/glean.cabal.in b/glean.cabal.in index 6bd4246ec..cf5b17030 100644 --- a/glean.cabal.in +++ b/glean.cabal.in @@ -521,7 +521,9 @@ library db Glean.Query.Prune Glean.Query.Reorder Glean.Query.Typecheck + Glean.Query.Typecheck.Monad Glean.Query.Typecheck.Types + Glean.Query.Typecheck.Unify Glean.Query.Vars Glean.Query.JSON diff --git a/glean/db/Glean/Query/Typecheck.hs b/glean/db/Glean/Query/Typecheck.hs index cf21ecfbd..d4e8e90cf 100644 --- a/glean/db/Glean/Query/Typecheck.hs +++ b/glean/db/Glean/Query/Typecheck.hs @@ -34,11 +34,9 @@ import Data.Char import Data.Foldable (toList) import Data.List.Extra (firstJust) import qualified Data.IntMap as IntMap -import Data.IntMap (IntMap) import qualified Data.Map as Map import Data.Maybe import qualified Data.HashMap.Strict as HashMap -import Data.HashMap.Strict (HashMap) import qualified Data.HashSet as HashSet import Data.HashSet (HashSet) import Data.List @@ -57,36 +55,16 @@ import Glean.Display import Glean.Query.Codegen.Types (Match(..), Var(..), QueryWithInfo(..), Typed(..)) import Glean.Query.Typecheck.Types +import Glean.Query.Typecheck.Monad +import Glean.Query.Typecheck.Unify import Glean.RTS.Types as RTS import Glean.RTS.Term hiding (Tuple, ByteArray, String, Array, Nat, All) import qualified Glean.RTS.Term as RTS -import qualified Glean.Database.Config as Config import Glean.Database.Schema.Types import Glean.Schema.Util import Glean.Util.Some -data TcEnv = TcEnv - { tcEnvTypes :: HashMap TypeId TypeDetails - , tcEnvPredicates :: HashMap PredicateId PredicateDetails - } - -emptyTcEnv :: TcEnv -emptyTcEnv = TcEnv HashMap.empty HashMap.empty - -data TcOpts = TcOpts - { tcOptDebug :: !Config.DebugFlags - , tcOptAngleVersion :: !AngleVersion - } - -defaultTcOpts :: Config.DebugFlags -> AngleVersion -> TcOpts -defaultTcOpts debug v = TcOpts - { tcOptDebug = debug - , tcOptAngleVersion = v - } - -type ToRtsType = Schema.Type -> Maybe Type - type Pat' s = SourcePat_ s PredicateId TypeId type Statement' s = SourceStatement_ s PredicateId TypeId type Query' s = SourceQuery_ s PredicateId TypeId @@ -721,27 +699,6 @@ falseVal, trueVal :: TcPat falseVal = RTS.Alt 0 (RTS.Tuple []) trueVal = RTS.Alt 1 (RTS.Tuple []) --- Smart constructor for wildcard patterns; replaces a wildcard that --- matches the unit type with a concrete pattern. This is necessary --- when we have an enum in an expression position: we can't translate --- @nothing@ into @{ nothing = _ }@ because the wildcard would be --- illegal in an expression. -mkWild :: Type -> TcPat -mkWild ty - | RecordTy [] <- derefType ty = RTS.Tuple [] - | otherwise = RTS.Ref (MatchWild ty) - -inPat :: (IsSrcSpan s) => Pat' s -> T a -> T a -inPat pat = addErrSpan (sourcePatSpan pat) - -addErrSpan :: (IsSrcSpan s) => s -> T a -> T a -addErrSpan span act = do - act `catchError` \errDoc -> do - prettyError $ vcat - [ pretty span - , errDoc - ] - patTypeError :: (IsSrcSpan s) => Pat' s -> Type -> T a patTypeError = patTypeErrorDesc "type error in pattern" @@ -755,76 +712,6 @@ patTypeErrorDesc desc q ty = do , "expected type: " <> display opts ty ] -data TcMode = TcModeQuery | TcModePredicate - deriving Eq - -data TypecheckState = TypecheckState - { tcEnv :: TcEnv - , tcAngleVersion :: AngleVersion - , tcDebug :: !Bool - , tcRtsType :: ToRtsType - , tcNextVar :: {-# UNPACK #-} !Int - , tcNextTyVar :: {-# UNPACK #-} !Int - , tcScope :: HashMap Name Var - -- ^ Variables that we have types for, and have allocated a Var - , tcVisible :: HashSet Name - -- ^ Variables that are currently visible - , tcFree :: HashSet Name - -- ^ Variables that are mentioned only once - , tcUses :: HashSet Name - -- ^ Accumulates variables that appear in an ContextExpr context - , tcBindings :: HashSet Name - -- ^ Accumulates variables that appear in an ContextPat context - , tcMode :: TcMode - , tcDisplayOpts :: DisplayOpts - -- ^ Options for pretty-printing - , tcVars :: IntMap Var - , tcSubst :: IntMap Type - , tcPromote :: [(Type, Type, Some IsSrcSpan)] - } - -initialTypecheckState - :: TcEnv - -> TcOpts - -> ToRtsType - -> TcMode - -> TypecheckState -initialTypecheckState tcEnv TcOpts{..} rtsType mode = TypecheckState - { tcEnv = tcEnv - , tcAngleVersion = tcOptAngleVersion - , tcDebug = Config.tcDebug tcOptDebug - , tcRtsType = rtsType - , tcNextVar = 0 - , tcNextTyVar = 0 - , tcScope = HashMap.empty - , tcVisible = HashSet.empty - , tcFree = HashSet.empty - , tcUses = HashSet.empty - , tcBindings = HashSet.empty - , tcMode = mode - , tcDisplayOpts = defaultDisplayOpts - -- might make this configurable with flags later - , tcSubst = IntMap.empty - , tcVars = IntMap.empty - , tcPromote = [] - } - -type T a = StateT TypecheckState (ExceptT (Doc ()) IO) a - -whenDebug :: T () -> T () -whenDebug act = do - lvl <- gets tcDebug - when lvl act - -freshTyVarInt :: T Int -freshTyVarInt = do - TypecheckState{tcNextTyVar = n} <- get - modify $ \s -> s { tcNextTyVar = n+1 } - return n - -freshTyVar :: T Type -freshTyVar = TyVar <$> freshTyVarInt - bindOrUse :: Context -> Name -> TypecheckState -> TypecheckState bindOrUse ContextExpr name state = state { tcUses = HashSet.insert name (tcUses state) } @@ -907,18 +794,6 @@ checkVarCase span name pretty name | otherwise = return () -prettyError :: Doc () -> T a -prettyError = throwError - -prettyErrorIn :: IsSrcSpan s => Pat' s -> Doc () -> T a -prettyErrorIn pat doc = prettyErrorAt (sourcePatSpan pat) doc - -prettyErrorAt :: IsSrcSpan span => span -> Doc () -> T a -prettyErrorAt span doc = prettyError $ vcat - [ pretty span - , doc - ] - -- | Typechecking A|B -- -- 1. The set of variables that are considered to be *bound* by this @@ -1328,280 +1203,3 @@ resolvePromote = do loop resolved2 loop promotes - -unify :: Type -> Type -> T () -unify ByteTy ByteTy = return () -unify NatTy NatTy = return () -unify StringTy StringTy = return () -unify (ArrayTy t) (ArrayTy u) = unify t u -unify a@(RecordTy ts) b@(RecordTy us) - | length ts == length us = forM_ (zip ts us) $ - \(FieldDef f t, FieldDef g u) -> - if f == g || compareStructurally - then unify t u - else unifyError a b - where - isTuple = all (Text.isInfixOf "tuplefield" . fieldDefName) - compareStructurally = isTuple ts || isTuple us - -- structural equality for tuples by ignoring field names. -unify a@(SumTy ts) b@(SumTy us) - | length ts == length us = forM_ (zip ts us) $ - \(FieldDef f t, FieldDef g u) -> - if f /= g then unifyError a b else unify t u -unify (PredicateTy (PidRef p _)) (PredicateTy (PidRef q _)) - | p == q = return () -unify (NamedTy (ExpandedType n _)) (NamedTy (ExpandedType m _)) - | n == m = return () -unify (NamedTy (ExpandedType _ t)) u = unify t u -unify t (NamedTy (ExpandedType _ u)) = unify t u -unify (MaybeTy t) (MaybeTy u) = unify t u -unify (MaybeTy t) u@SumTy{} = unify (lowerMaybe t) u -unify t@SumTy{} (MaybeTy u) = unify t (lowerMaybe u) -unify (EnumeratedTy ns) (EnumeratedTy ms) | ns == ms = return () -unify (EnumeratedTy ns) u@SumTy{} = unify (lowerEnum ns) u -unify t@SumTy{} (EnumeratedTy ns) = unify t (lowerEnum ns) -unify BooleanTy BooleanTy = return () - -unify (TyVar x) (TyVar y) | x == y = return () -unify (TyVar x) t = extend x t -unify t (TyVar x) = extend x t - -unify (HasTy a ra x) (HasTy b rb y) = do - mapM_ (uncurry unify) $ Map.intersectionWith (,) a b - z <- freshTyVarInt - let all = HasTy (Map.union a b) (ra || rb) z - extend x all - extend y all - -unify a@(HasTy m _ x) b@(RecordTy fs) = do - forM_ fs $ \(FieldDef f ty) -> - case Map.lookup f m of - Nothing -> return () - Just ty' -> unify ty ty' - forM_ (Map.keys m) $ \n -> - when (n `notElem` map fieldDefName fs) $ - unifyError a b - extend x (RecordTy fs) - -unify a@(HasTy m False x) b@(SumTy fs) = do - forM_ fs $ \(FieldDef f ty) -> - case Map.lookup f m of - Nothing -> return () - Just ty' -> unify ty ty' - forM_ (Map.keys m) $ \n -> - when (n `notElem` map fieldDefName fs) $ - unifyError a b - extend x (SumTy fs) - -unify a@RecordTy{} b@HasTy{} = unify b a -unify a@SumTy{} b@HasTy{} = unify b a - -unify a b = unifyError a b - -unifyError :: Type -> Type -> T a -unifyError a b = do - opts <- gets tcDisplayOpts - prettyError $ vcat - [ "type error:" - , indent 2 (display opts a) - , "does not match:" - , indent 2 (display opts b) - ] - -extend :: Int -> Type -> T () -extend x t = do - t' <- apply t -- avoid creating a cycle in the substitution - if - | TyVar y <- t, y == x -> return () - | otherwise -> do - subst <- gets tcSubst - case IntMap.lookup x subst of - Just u -> unify t' u - Nothing -> - modify $ \s -> s{ tcSubst = IntMap.insert x t' (tcSubst s) } - -apply :: Type -> T Type -apply t = do - subst <- gets tcSubst - apply_ (\x -> return (IntMap.lookup x subst)) t - -zonkType :: Type -> T Type -zonkType t = do - subst <- gets tcSubst - opts <- gets tcDisplayOpts - let lookup x = case IntMap.lookup x subst of - Nothing -> prettyError $ "ambiguous type: " <> - display opts (TyVar x :: Type) - Just u -> return (Just u) - apply_ lookup t - -apply_ :: (Int -> T (Maybe Type)) -> Type -> T Type -apply_ lookup t = go t - where - go t = case t of - ByteTy -> return t - NatTy -> return t - StringTy -> return t - ArrayTy t -> ArrayTy <$> go t - RecordTy fs -> - fmap RecordTy $ forM fs $ \(FieldDef n t) -> - FieldDef n <$> go t - SumTy fs -> - fmap SumTy $ forM fs $ \(FieldDef n t) -> - FieldDef n <$> go t - PredicateTy{} -> return t - NamedTy{} -> return t - MaybeTy t -> MaybeTy <$> go t - EnumeratedTy{} -> return t - BooleanTy -> return t - TyVar x -> do - m <- lookup x - case m of - Nothing -> return t - Just u -> go u - HasTy _ _ x -> do - m <- lookup x - case m of - Nothing -> return t - Just u -> go u - SetTy t -> SetTy <$> go t - -zonkVars :: T () -zonkVars = do - vars <- gets tcVars - subst <- gets tcSubst - zonked <- forM vars $ \Var{..} -> do - let - lookup x = case IntMap.lookup x subst of - Nothing -> prettyError $ vcat - [ "variable " <> pretty var <> - " has unknown type" - , " try adding a type signature, like: " <> pretty var <> " : T" - ] - where var = fromMaybe (Text.pack ('_':show varId)) varOrigName - Just u -> return (Just u) - t <- apply_ lookup varType - return (Var { varType = t, ..}) - modify $ \s -> s { tcVars = zonked } - -zonkTcQuery :: TcQuery -> T TcQuery -zonkTcQuery (TcQuery ty k mv stmts) = - TcQuery - <$> zonkType ty - <*> zonkTcPat k - <*> mapM zonkTcPat mv - <*> mapM zonkTcStatement stmts - -zonkTcPat :: TcPat -> T TcPat -zonkTcPat p = case p of - RTS.Byte{} -> return p - RTS.Nat{} -> return p - RTS.Array ts -> RTS.Array <$> mapM zonkTcPat ts - RTS.ByteArray{} -> return p - RTS.Tuple ts -> RTS.Tuple <$> mapM zonkTcPat ts - RTS.Alt n t -> RTS.Alt n <$> zonkTcPat t - RTS.String{} -> return p - RTS.All ts -> RTS.All <$> mapM zonkTcPat ts - Ref (MatchExt (Typed ty (TcPromote inner e))) -> do - ty' <- zonkType ty - inner' <- zonkType inner - e' <- zonkTcPat e - case (ty', inner') of - (TyVar{}, _) -> error "zonkMatch: tyvar" - (_, TyVar{}) -> error "zonkMatch: tyvar" - (PredicateTy (PidRef _ ref), PredicateTy (PidRef _ ref')) - | ref == ref' -> return e' - (PredicateTy pidRef@(PidRef _ ref), _other) -> do - PredicateDetails{..} <- getPredicateDetails ref - let vpat = Ref (MatchWild predicateValueType) - return (Ref (MatchExt (Typed ty' - (TcFactGen pidRef e' vpat SeekOnAllFacts)))) - _ -> - return e' - Ref (MatchExt (Typed ty (TcStructPat fs))) -> do - ty' <- zonkType ty - case ty' of - RecordTy fields -> - fmap RTS.Tuple $ forM fields $ \(FieldDef f ty) -> - case [ p | (g,p) <- fs, f == g ] of - [] -> return (mkWild ty) - (x:_) -> zonkTcPat x - SumTy fields -> - case fs of - [(name,pat)] - | (_, n) :_ <- lookupField name fields -> do - pat' <- zonkTcPat pat - return (RTS.Alt n pat') - _other -> error $ "zonkTcPat: " <> show (displayDefault p) - _other -> do - opts <- gets tcDisplayOpts - prettyError $ - nest 4 $ vcat - [ "type error in pattern" - , "pattern: " <> display opts p - , "expected type: " <> display opts ty - ] - - Ref m -> Ref <$> zonkMatch m - -zonkMatch :: Match (Typed TcTerm) Var -> T (Match (Typed TcTerm) Var) -zonkMatch m = case m of - MatchWild ty -> MatchWild <$> zonkType ty - MatchNever ty -> MatchNever <$> zonkType ty - MatchFid{} -> return m - MatchBind v -> MatchBind <$> var v - MatchVar v -> MatchVar <$> var v - MatchAnd a b -> MatchAnd <$> zonkTcPat a <*> zonkTcPat b - MatchPrefix s pat -> MatchPrefix s <$> zonkTcPat pat - MatchArrayPrefix ty ts -> - MatchArrayPrefix <$> zonkType ty <*> mapM zonkTcPat ts - MatchExt (Typed ty e) -> - MatchExt <$> (Typed <$> zonkType ty <*> zonkTcTerm e) - where - var (Var _ n _) = do - vars <- gets tcVars - case IntMap.lookup n vars of - Nothing -> error "zonkMatch" - Just v -> return v - -zonkTcTerm :: TcTerm -> T TcTerm -zonkTcTerm t = case t of - TcOr a b -> TcOr <$> zonkTcPat a <*> zonkTcPat b - TcFactGen pid k v sec -> - TcFactGen pid <$> zonkTcPat k <*> zonkTcPat v <*> pure sec - TcElementsOfArray a -> TcElementsOfArray <$> zonkTcPat a - TcQueryGen q -> TcQueryGen <$> zonkTcQuery q - TcNegation stmts -> TcNegation <$> mapM zonkTcStatement stmts - TcPrimCall op args -> TcPrimCall op <$> mapM zonkTcPat args - TcIf (Typed ty cond) th el -> - TcIf - <$> (Typed <$> zonkType ty <*> zonkTcPat cond) - <*> zonkTcPat th - <*> zonkTcPat el - TcDeref ty valTy p -> - TcDeref <$> zonkType ty <*> zonkType valTy <*> zonkTcPat p - TcFieldSelect (Typed ty p) f -> - TcFieldSelect - <$> (Typed <$> zonkType ty <*> zonkTcPat p) - <*> pure f - TcAltSelect (Typed ty p) f -> - TcAltSelect - <$> (Typed <$> zonkType ty <*> zonkTcPat p) - <*> pure f - TcElements p -> TcElements <$> zonkTcPat p - TcPromote{} -> error "zonkTcTerm: TcPromote" -- handled in zonkTcPat - TcStructPat{} -> error "zonkTcTerm: TcStructPat" -- handled in zonkTcPat - -zonkTcStatement :: TcStatement -> T TcStatement -zonkTcStatement (TcStatement ty l r) = - TcStatement - <$> zonkType ty - <*> zonkTcPat l - <*> zonkTcPat r - -getPredicateDetails :: PredicateId -> T PredicateDetails -getPredicateDetails pred = do - TcEnv{..} <- gets tcEnv - case HashMap.lookup pred tcEnvPredicates of - Nothing -> error $ "predicateKeyTYpe: " <> show (displayDefault pred) - Just d -> return d diff --git a/glean/db/Glean/Query/Typecheck/Monad.hs b/glean/db/Glean/Query/Typecheck/Monad.hs new file mode 100644 index 000000000..50596b5c7 --- /dev/null +++ b/glean/db/Glean/Query/Typecheck/Monad.hs @@ -0,0 +1,183 @@ +{- + Copyright (c) Meta Platforms, Inc. and affiliates. + All rights reserved. + + This source code is licensed under the BSD-style license found in the + LICENSE file in the root directory of this source tree. +-} + +module Glean.Query.Typecheck.Monad ( + T, + ToRtsType, + TypecheckState(..), + TcOpts(..), + TcMode(..), + defaultTcOpts, + TcEnv(..), + emptyTcEnv, + initialTypecheckState, + whenDebug, + freshTyVar, + freshTyVarInt, + getPredicateDetails, + mkWild, + + -- * Errors + prettyError, + prettyErrorIn, + prettyErrorAt, + inPat, + addErrSpan, + ) where + +import Control.Monad.Except +import Control.Monad.State +import qualified Data.HashMap.Strict as HashMap +import Data.HashMap.Strict (HashMap) +import qualified Data.HashSet as HashSet +import Data.HashSet (HashSet) +import qualified Data.IntMap as IntMap +import Data.IntMap (IntMap) +import Data.Text.Prettyprint.Doc hiding ((<>), enclose) + +import Glean.Angle.Types hiding (Type) +import qualified Glean.Angle.Types as Schema +import qualified Glean.Database.Config as Config +import Glean.Database.Schema.Types +import Glean.Display +import Glean.Query.Typecheck.Types +import Glean.Query.Codegen.Types +import qualified Glean.RTS.Term as RTS +import Glean.RTS.Types as RTS +import Glean.Util.Some + +type T a = StateT TypecheckState (ExceptT (Doc ()) IO) a + +type ToRtsType = Schema.Type -> Maybe Type + +whenDebug :: T () -> T () +whenDebug act = do + lvl <- gets tcDebug + when lvl act + +freshTyVarInt :: T Int +freshTyVarInt = do + TypecheckState{tcNextTyVar = n} <- get + modify $ \s -> s { tcNextTyVar = n+1 } + return n + +freshTyVar :: T Type +freshTyVar = TyVar <$> freshTyVarInt + +data TcMode = TcModeQuery | TcModePredicate + deriving Eq + +data TcEnv = TcEnv + { tcEnvTypes :: HashMap TypeId TypeDetails + , tcEnvPredicates :: HashMap PredicateId PredicateDetails + } + +emptyTcEnv :: TcEnv +emptyTcEnv = TcEnv HashMap.empty HashMap.empty + +data TypecheckState = TypecheckState + { tcEnv :: TcEnv + , tcAngleVersion :: AngleVersion + , tcDebug :: !Bool + , tcRtsType :: ToRtsType + , tcNextVar :: {-# UNPACK #-} !Int + , tcNextTyVar :: {-# UNPACK #-} !Int + , tcScope :: HashMap Name Var + -- ^ Variables that we have types for, and have allocated a Var + , tcVisible :: HashSet Name + -- ^ Variables that are currently visible + , tcFree :: HashSet Name + -- ^ Variables that are mentioned only once + , tcUses :: HashSet Name + -- ^ Accumulates variables that appear in an ContextExpr context + , tcBindings :: HashSet Name + -- ^ Accumulates variables that appear in an ContextPat context + , tcMode :: TcMode + , tcDisplayOpts :: DisplayOpts + -- ^ Options for pretty-printing + , tcVars :: IntMap Var + , tcSubst :: IntMap Type + , tcPromote :: [(Type, Type, Some IsSrcSpan)] + } + +data TcOpts = TcOpts + { tcOptDebug :: !Config.DebugFlags + , tcOptAngleVersion :: !AngleVersion + } + +defaultTcOpts :: Config.DebugFlags -> AngleVersion -> TcOpts +defaultTcOpts debug v = TcOpts + { tcOptDebug = debug + , tcOptAngleVersion = v + } + +initialTypecheckState + :: TcEnv + -> TcOpts + -> ToRtsType + -> TcMode + -> TypecheckState +initialTypecheckState tcEnv TcOpts{..} rtsType mode = TypecheckState + { tcEnv = tcEnv + , tcAngleVersion = tcOptAngleVersion + , tcDebug = Config.tcDebug tcOptDebug + , tcRtsType = rtsType + , tcNextVar = 0 + , tcNextTyVar = 0 + , tcScope = HashMap.empty + , tcVisible = HashSet.empty + , tcFree = HashSet.empty + , tcUses = HashSet.empty + , tcBindings = HashSet.empty + , tcMode = mode + , tcDisplayOpts = defaultDisplayOpts + -- might make this configurable with flags later + , tcSubst = IntMap.empty + , tcVars = IntMap.empty + , tcPromote = [] + } + +getPredicateDetails :: PredicateId -> T PredicateDetails +getPredicateDetails pred = do + TcEnv{..} <- gets tcEnv + case HashMap.lookup pred tcEnvPredicates of + Nothing -> error $ "predicateKeyTYpe: " <> show (displayDefault pred) + Just d -> return d + +-- Smart constructor for wildcard patterns; replaces a wildcard that +-- matches the unit type with a concrete pattern. This is necessary +-- when we have an enum in an expression position: we can't translate +-- @nothing@ into @{ nothing = _ }@ because the wildcard would be +-- illegal in an expression. +mkWild :: Type -> TcPat +mkWild ty + | RecordTy [] <- derefType ty = RTS.Tuple [] + | otherwise = RTS.Ref (MatchWild ty) + +prettyError :: Doc () -> T a +prettyError = throwError + +prettyErrorIn :: IsSrcSpan s => SourcePat_ s p t -> Doc () -> T a +prettyErrorIn pat doc = prettyErrorAt (sourcePatSpan pat) doc + +prettyErrorAt :: IsSrcSpan span => span -> Doc () -> T a +prettyErrorAt span doc = prettyError $ vcat + [ pretty span + , doc + ] + +inPat :: (IsSrcSpan s) => SourcePat_ s p t -> T a -> T a +inPat pat = addErrSpan (sourcePatSpan pat) + +addErrSpan :: (IsSrcSpan s) => s -> T a -> T a +addErrSpan span act = do + act `catchError` \errDoc -> do + prettyError $ vcat + [ pretty span + , errDoc + ] diff --git a/glean/db/Glean/Query/Typecheck/Unify.hs b/glean/db/Glean/Query/Typecheck/Unify.hs new file mode 100644 index 000000000..7b22cfc05 --- /dev/null +++ b/glean/db/Glean/Query/Typecheck/Unify.hs @@ -0,0 +1,304 @@ +{- + Copyright (c) Meta Platforms, Inc. and affiliates. + All rights reserved. + + This source code is licensed under the BSD-style license found in the + LICENSE file in the root directory of this source tree. +-} + +module Glean.Query.Typecheck.Unify ( + unify, + apply, + zonkTcQuery, + zonkVars, + ) where + +import Control.Monad.Except +import Control.Monad.State +import qualified Data.IntMap as IntMap +import qualified Data.Map as Map +import Data.Maybe +import qualified Data.Text as Text +import Data.Text.Prettyprint.Doc hiding ((<>), enclose) + +import Glean.Angle.Types hiding (Type) +import Glean.Database.Schema.Types +import Glean.Display +import Glean.Query.Codegen.Types +import Glean.Query.Typecheck.Monad +import Glean.Query.Typecheck.Types +import Glean.RTS.Term hiding + (Tuple, ByteArray, String, Array, Nat, All) +import qualified Glean.RTS.Term as RTS +import Glean.RTS.Types as RTS +import Glean.Schema.Util + +unify :: Type -> Type -> T () +unify ByteTy ByteTy = return () +unify NatTy NatTy = return () +unify StringTy StringTy = return () +unify (ArrayTy t) (ArrayTy u) = unify t u +unify a@(RecordTy ts) b@(RecordTy us) + | length ts == length us = forM_ (zip ts us) $ + \(FieldDef f t, FieldDef g u) -> + if f == g || compareStructurally + then unify t u + else unifyError a b + where + isTuple = all (Text.isInfixOf "tuplefield" . fieldDefName) + compareStructurally = isTuple ts || isTuple us + -- structural equality for tuples by ignoring field names. +unify a@(SumTy ts) b@(SumTy us) + | length ts == length us = forM_ (zip ts us) $ + \(FieldDef f t, FieldDef g u) -> + if f /= g then unifyError a b else unify t u +unify (PredicateTy (PidRef p _)) (PredicateTy (PidRef q _)) + | p == q = return () +unify (NamedTy (ExpandedType n _)) (NamedTy (ExpandedType m _)) + | n == m = return () +unify (NamedTy (ExpandedType _ t)) u = unify t u +unify t (NamedTy (ExpandedType _ u)) = unify t u +unify (MaybeTy t) (MaybeTy u) = unify t u +unify (MaybeTy t) u@SumTy{} = unify (lowerMaybe t) u +unify t@SumTy{} (MaybeTy u) = unify t (lowerMaybe u) +unify (EnumeratedTy ns) (EnumeratedTy ms) | ns == ms = return () +unify (EnumeratedTy ns) u@SumTy{} = unify (lowerEnum ns) u +unify t@SumTy{} (EnumeratedTy ns) = unify t (lowerEnum ns) +unify BooleanTy BooleanTy = return () + +unify (TyVar x) (TyVar y) | x == y = return () +unify (TyVar x) t = extend x t +unify t (TyVar x) = extend x t + +unify (HasTy a ra x) (HasTy b rb y) = do + mapM_ (uncurry unify) $ Map.intersectionWith (,) a b + z <- freshTyVarInt + let all = HasTy (Map.union a b) (ra || rb) z + extend x all + extend y all + +unify a@(HasTy m _ x) b@(RecordTy fs) = do + forM_ fs $ \(FieldDef f ty) -> + case Map.lookup f m of + Nothing -> return () + Just ty' -> unify ty ty' + forM_ (Map.keys m) $ \n -> + when (n `notElem` map fieldDefName fs) $ + unifyError a b + extend x (RecordTy fs) + +unify a@(HasTy m False x) b@(SumTy fs) = do + forM_ fs $ \(FieldDef f ty) -> + case Map.lookup f m of + Nothing -> return () + Just ty' -> unify ty ty' + forM_ (Map.keys m) $ \n -> + when (n `notElem` map fieldDefName fs) $ + unifyError a b + extend x (SumTy fs) + +unify a@RecordTy{} b@HasTy{} = unify b a +unify a@SumTy{} b@HasTy{} = unify b a + +unify a b = unifyError a b + +unifyError :: Type -> Type -> T a +unifyError a b = do + opts <- gets tcDisplayOpts + prettyError $ vcat + [ "type error:" + , indent 2 (display opts a) + , "does not match:" + , indent 2 (display opts b) + ] + +extend :: Int -> Type -> T () +extend x t = do + t' <- apply t -- avoid creating a cycle in the substitution + if + | TyVar y <- t, y == x -> return () + | otherwise -> do + subst <- gets tcSubst + case IntMap.lookup x subst of + Just u -> unify t' u + Nothing -> + modify $ \s -> s{ tcSubst = IntMap.insert x t' (tcSubst s) } + +apply :: Type -> T Type +apply t = do + subst <- gets tcSubst + apply_ (\x -> return (IntMap.lookup x subst)) t + +zonkType :: Type -> T Type +zonkType t = do + subst <- gets tcSubst + opts <- gets tcDisplayOpts + let lookup x = case IntMap.lookup x subst of + Nothing -> prettyError $ "ambiguous type: " <> + display opts (TyVar x :: Type) + Just u -> return (Just u) + apply_ lookup t + +apply_ :: (Int -> T (Maybe Type)) -> Type -> T Type +apply_ lookup t = go t + where + go t = case t of + ByteTy -> return t + NatTy -> return t + StringTy -> return t + ArrayTy t -> ArrayTy <$> go t + RecordTy fs -> + fmap RecordTy $ forM fs $ \(FieldDef n t) -> + FieldDef n <$> go t + SumTy fs -> + fmap SumTy $ forM fs $ \(FieldDef n t) -> + FieldDef n <$> go t + PredicateTy{} -> return t + NamedTy{} -> return t + MaybeTy t -> MaybeTy <$> go t + EnumeratedTy{} -> return t + BooleanTy -> return t + TyVar x -> do + m <- lookup x + case m of + Nothing -> return t + Just u -> go u + HasTy _ _ x -> do + m <- lookup x + case m of + Nothing -> return t + Just u -> go u + SetTy t -> SetTy <$> go t + +zonkVars :: T () +zonkVars = do + vars <- gets tcVars + subst <- gets tcSubst + zonked <- forM vars $ \Var{..} -> do + let + lookup x = case IntMap.lookup x subst of + Nothing -> prettyError $ vcat + [ "variable " <> pretty var <> + " has unknown type" + , " try adding a type signature, like: " <> pretty var <> " : T" + ] + where var = fromMaybe (Text.pack ('_':show varId)) varOrigName + Just u -> return (Just u) + t <- apply_ lookup varType + return (Var { varType = t, ..}) + modify $ \s -> s { tcVars = zonked } + +zonkTcQuery :: TcQuery -> T TcQuery +zonkTcQuery (TcQuery ty k mv stmts) = + TcQuery + <$> zonkType ty + <*> zonkTcPat k + <*> mapM zonkTcPat mv + <*> mapM zonkTcStatement stmts + +zonkTcPat :: TcPat -> T TcPat +zonkTcPat p = case p of + RTS.Byte{} -> return p + RTS.Nat{} -> return p + RTS.Array ts -> RTS.Array <$> mapM zonkTcPat ts + RTS.ByteArray{} -> return p + RTS.Tuple ts -> RTS.Tuple <$> mapM zonkTcPat ts + RTS.Alt n t -> RTS.Alt n <$> zonkTcPat t + RTS.String{} -> return p + RTS.All ts -> RTS.All <$> mapM zonkTcPat ts + Ref (MatchExt (Typed ty (TcPromote inner e))) -> do + ty' <- zonkType ty + inner' <- zonkType inner + e' <- zonkTcPat e + case (ty', inner') of + (TyVar{}, _) -> error "zonkMatch: tyvar" + (_, TyVar{}) -> error "zonkMatch: tyvar" + (PredicateTy (PidRef _ ref), PredicateTy (PidRef _ ref')) + | ref == ref' -> return e' + (PredicateTy pidRef@(PidRef _ ref), _other) -> do + PredicateDetails{..} <- getPredicateDetails ref + let vpat = Ref (MatchWild predicateValueType) + return (Ref (MatchExt (Typed ty' + (TcFactGen pidRef e' vpat SeekOnAllFacts)))) + _ -> + return e' + Ref (MatchExt (Typed ty (TcStructPat fs))) -> do + ty' <- zonkType ty + case ty' of + RecordTy fields -> + fmap RTS.Tuple $ forM fields $ \(FieldDef f ty) -> + case [ p | (g,p) <- fs, f == g ] of + [] -> return (mkWild ty) + (x:_) -> zonkTcPat x + SumTy fields -> + case fs of + [(name,pat)] + | (_, n) :_ <- lookupField name fields -> do + pat' <- zonkTcPat pat + return (RTS.Alt n pat') + _other -> error $ "zonkTcPat: " <> show (displayDefault p) + _other -> do + opts <- gets tcDisplayOpts + prettyError $ + nest 4 $ vcat + [ "type error in pattern" + , "pattern: " <> display opts p + , "expected type: " <> display opts ty + ] + + Ref m -> Ref <$> zonkMatch m + +zonkMatch :: Match (Typed TcTerm) Var -> T (Match (Typed TcTerm) Var) +zonkMatch m = case m of + MatchWild ty -> MatchWild <$> zonkType ty + MatchNever ty -> MatchNever <$> zonkType ty + MatchFid{} -> return m + MatchBind v -> MatchBind <$> var v + MatchVar v -> MatchVar <$> var v + MatchAnd a b -> MatchAnd <$> zonkTcPat a <*> zonkTcPat b + MatchPrefix s pat -> MatchPrefix s <$> zonkTcPat pat + MatchArrayPrefix ty ts -> + MatchArrayPrefix <$> zonkType ty <*> mapM zonkTcPat ts + MatchExt (Typed ty e) -> + MatchExt <$> (Typed <$> zonkType ty <*> zonkTcTerm e) + where + var (Var _ n _) = do + vars <- gets tcVars + case IntMap.lookup n vars of + Nothing -> error "zonkMatch" + Just v -> return v + +zonkTcTerm :: TcTerm -> T TcTerm +zonkTcTerm t = case t of + TcOr a b -> TcOr <$> zonkTcPat a <*> zonkTcPat b + TcFactGen pid k v sec -> + TcFactGen pid <$> zonkTcPat k <*> zonkTcPat v <*> pure sec + TcElementsOfArray a -> TcElementsOfArray <$> zonkTcPat a + TcQueryGen q -> TcQueryGen <$> zonkTcQuery q + TcNegation stmts -> TcNegation <$> mapM zonkTcStatement stmts + TcPrimCall op args -> TcPrimCall op <$> mapM zonkTcPat args + TcIf (Typed ty cond) th el -> + TcIf + <$> (Typed <$> zonkType ty <*> zonkTcPat cond) + <*> zonkTcPat th + <*> zonkTcPat el + TcDeref ty valTy p -> + TcDeref <$> zonkType ty <*> zonkType valTy <*> zonkTcPat p + TcFieldSelect (Typed ty p) f -> + TcFieldSelect + <$> (Typed <$> zonkType ty <*> zonkTcPat p) + <*> pure f + TcAltSelect (Typed ty p) f -> + TcAltSelect + <$> (Typed <$> zonkType ty <*> zonkTcPat p) + <*> pure f + TcElements p -> TcElements <$> zonkTcPat p + TcPromote{} -> error "zonkTcTerm: TcPromote" -- handled in zonkTcPat + TcStructPat{} -> error "zonkTcTerm: TcStructPat" -- handled in zonkTcPat + +zonkTcStatement :: TcStatement -> T TcStatement +zonkTcStatement (TcStatement ty l r) = + TcStatement + <$> zonkType ty + <*> zonkTcPat l + <*> zonkTcPat r