Skip to content

Commit

Permalink
Schema support for Sqlite backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
arrowd committed Feb 1, 2023
1 parent 8970320 commit ddedf5d
Showing 1 changed file with 54 additions and 27 deletions.
81 changes: 54 additions & 27 deletions persistent-sqlite/Database/Persist/Sqlite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ wrapConnectionInfo connInfo conn logFunc = do
, connCommit = helper "COMMIT"
, connRollback = ignoreExceptions . helper "ROLLBACK"
, connEscapeFieldName = escape . unFieldNameDB
, connEscapeTableName = escape . unEntityNameDB . getEntityDBName
, connEscapeTableName = escapeS . schemaNamePair
, connEscapeRawName = escape
, connNoLimit = "LIMIT -1"
, connRDBMS = "sqlite"
Expand Down Expand Up @@ -358,7 +358,7 @@ insertSql' ent vals =
ISRManyKeys sql vals
where sql = T.concat
[ "INSERT INTO "
, escapeE $ getEntityDBName ent
, escapeS $ schemaNamePair ent
, "("
, T.intercalate "," $ map (escapeF . fieldDB) cols
, ") VALUES("
Expand All @@ -372,12 +372,12 @@ insertSql' ent vals =
[ "SELECT "
, escapeF $ fieldDB fd
, " FROM "
, escapeE $ getEntityDBName ent
, escapeS $ schemaNamePair ent
, " WHERE _ROWID_=last_insert_rowid()"
]
ins = T.concat
[ "INSERT INTO "
, escapeE $ getEntityDBName ent
, escapeS $ schemaNamePair ent
, if null cols
then " VALUES(null)"
else T.concat
Expand Down Expand Up @@ -434,8 +434,14 @@ showSqlType SqlBlob = "BLOB"
showSqlType SqlBool = "BOOLEAN"
showSqlType (SqlOther t) = t

type SchemaEntityName = (Text, EntityNameDB)

sqliteMkColumns :: [EntityDef] -> EntityDef -> ([Column], [UniqueDef], [ForeignDef])
sqliteMkColumns allDefs t = mkColumns allDefs t emptyBackendSpecificOverrides
sqliteMkColumns allDefs t = mkColumns allDefs t sqliteSpecificOverrides

sqliteSpecificOverrides :: BackendSpecificOverrides
sqliteSpecificOverrides = setBackendSpecificSchemaEntityName (\schema name -> unescapeS $ schemaNamePair' schema name)
$ emptyBackendSpecificOverrides

migrate'
:: [EntityDef]
Expand All @@ -444,9 +450,9 @@ migrate'
-> IO (Either [Text] CautiousMigration)
migrate' allDefs getter val = do
let (cols, uniqs, fdefs) = sqliteMkColumns allDefs val
let newSql = mkCreateTable False def (filter (not . safeToRemove val . cName) cols, uniqs, fdefs)
let newSql = mkCreateTable def (filter (not . safeToRemove val . cName) cols, uniqs, fdefs)
stmt <- getter "SELECT sql FROM sqlite_master WHERE type='table' AND name=?"
oldSql' <- with (stmtQuery stmt [PersistText $ unEntityNameDB table])
oldSql' <- with (stmtQuery stmt [PersistText table])
(\src -> runConduit $ src .| go)
case oldSql' of
Nothing -> return $ Right [(False, newSql)]
Expand All @@ -458,7 +464,7 @@ migrate' allDefs getter val = do
return $ Right sql
where
def = val
table = getEntityDBName def
table = unEntityNameDB $ unescapeS $ schemaNamePair def
go = do
x <- CL.head
case x of
Expand Down Expand Up @@ -490,7 +496,7 @@ mockMigration mig = do
, connCommit = helper "COMMIT"
, connRollback = ignoreExceptions . helper "ROLLBACK"
, connEscapeFieldName = escape . unFieldNameDB
, connEscapeTableName = escape . unEntityNameDB . getEntityDBName
, connEscapeTableName = escapeS . schemaNamePair
, connEscapeRawName = escape
, connNoLimit = "LIMIT -1"
, connRDBMS = "sqlite"
Expand Down Expand Up @@ -528,7 +534,7 @@ getCopyTable :: [EntityDef]
-> EntityDef
-> IO [(Bool, Text)]
getCopyTable allDefs getter def = do
stmt <- getter $ T.concat [ "PRAGMA table_info(", escapeE table, ")" ]
stmt <- getter $ T.concat [ "PRAGMA table_info(", escapeS table, ")" ]
oldCols' <- with (stmtQuery stmt []) (\src -> runConduit $ src .| getCols)
let oldCols = map FieldNameDB oldCols'
let newCols = filter (not . safeToRemove def) $ map cName cols
Expand All @@ -549,42 +555,44 @@ getCopyTable allDefs getter def = do
names <- getCols
return $ name : names
Just y -> error $ "Invalid result from PRAGMA table_info: " ++ show y
table = getEntityDBName def
tableTmp = EntityNameDB $ unEntityNameDB table <> "_backup"
defTmp = setEntityDBSchema (Just "temp")
$ setEntityDBName (escapeWith (EntityNameDB . (<> "_backup")) $ getEntityDBName def)
def
table = schemaNamePair def
tableTmp = schemaNamePair defTmp
(cols, uniqs, fdef) = sqliteMkColumns allDefs def
cols' = filter (not . safeToRemove def . cName) cols
newSql = mkCreateTable False def (cols', uniqs, fdef)
tmpSql = mkCreateTable True (setEntityDBName tableTmp def) (cols', uniqs, [])
dropTmp = "DROP TABLE " <> escapeE tableTmp
dropOld = "DROP TABLE " <> escapeE table
newSql = mkCreateTable def (cols', uniqs, fdef)
tmpSql = mkCreateTable defTmp (cols', uniqs, [])
dropTmp = "DROP TABLE " <> escapeS tableTmp
dropOld = "DROP TABLE " <> escapeS table
copyToTemp common = T.concat
[ "INSERT INTO "
, escapeE tableTmp
, escapeS tableTmp
, "("
, T.intercalate "," $ map escapeF common
, ") SELECT "
, T.intercalate "," $ map escapeF common
, " FROM "
, escapeE table
, escapeS table
]
copyToFinal newCols = T.concat
[ "INSERT INTO "
, escapeE table
, escapeS table
, " SELECT "
, T.intercalate "," $ map escapeF newCols
, " FROM "
, escapeE tableTmp
, escapeS tableTmp
]

mkCreateTable :: Bool -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -> Text
mkCreateTable isTemp entity (cols, uniqs, fdefs) =
mkCreateTable :: EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -> Text
mkCreateTable entity (cols, uniqs, fdefs) =
T.concat (header <> columns <> footer)
where
isTemp = getEntityDBSchema entity == Just "temp"
header =
[ "CREATE"
, if isTemp then " TEMP" else ""
, " TABLE "
, escapeE $ getEntityDBName entity
[ "CREATE TABLE "
, escapeS $ schemaNamePair entity
, "("
]

Expand Down Expand Up @@ -678,6 +686,16 @@ sqlUnique (UniqueDef _ cname cols _) = T.concat
, ")"
]

schemaNamePair :: EntityDef -> SchemaEntityName
schemaNamePair ent = schemaNamePair' (getEntityDBSchema ent) (getEntityDBName ent)

schemaNamePair' :: Maybe Text -> EntityNameDB -> SchemaEntityName
schemaNamePair' mbSchema entName = case mbSchema of
Nothing -> ("", entName)
Just "main" -> ("", entName)
Just "temp" -> ("temp", entName)
Just schema -> ("", EntityNameDB $ (schema <> "_") <> unEntityNameDB entName)

escapeC :: ConstraintNameDB -> Text
escapeC = escapeWith escape

Expand All @@ -687,6 +705,15 @@ escapeE = escapeWith escape
escapeF :: FieldNameDB -> Text
escapeF = escapeWith escape

escapeS :: (Text, EntityNameDB) -> Text
escapeS ("", entDBName) = escapeE entDBName
-- no need to escape schema as it is either "" or "temp"
escapeS (schema, entDBName) = schema <> "." <> escapeE entDBName

unescapeS :: (Text, EntityNameDB) -> EntityNameDB
unescapeS ("", entDBName) = entDBName
unescapeS (schema, entDBName) = EntityNameDB $ escapeWith ((schema <> ".") <>) entDBName

escape :: Text -> Text
escape s =
T.concat [q, T.concatMap go s, q]
Expand All @@ -713,7 +740,7 @@ putManySql' conflictColumns fields ent n = q
fieldDbToText = escapeF . fieldDB
mkAssignment f = T.concat [f, "=EXCLUDED.", f]

table = escapeE . getEntityDBName $ ent
table = escapeS . schemaNamePair $ ent
columns = Util.commaSeparated $ map fieldDbToText fields
placeholders = map (const "?") fields
updates = map (mkAssignment . fieldDbToText) fields
Expand Down

0 comments on commit ddedf5d

Please sign in to comment.