Skip to content

Commit

Permalink
apply filter-branch changes to working/staged changes (#7925)
Browse files Browse the repository at this point in the history
  • Loading branch information
jycor authored May 30, 2024
1 parent 13cec7d commit 8c9cfe7
Show file tree
Hide file tree
Showing 4 changed files with 393 additions and 141 deletions.
179 changes: 119 additions & 60 deletions go/cmd/dolt/commands/filter-branch.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ import (
)

const (
filterDbName = "filterDB"
branchesFlag = "branches"
filterDbName = "filterDB"
branchesFlag = "branches"
uncommittedFlag = "apply-to-uncommitted"
)

var filterBranchDocs = cli.CommandDocumentationContent{
Expand Down Expand Up @@ -82,6 +83,7 @@ func (cmd FilterBranchCmd) ArgParser() *argparser.ArgParser {
ap := argparser.NewArgParserWithVariableArgs(cmd.Name())
ap.SupportsFlag(cli.VerboseFlag, "v", "logs more information")
ap.SupportsFlag(branchesFlag, "b", "filter all branches")
ap.SupportsFlag(uncommittedFlag, "", "apply changes to uncommitted tables")
ap.SupportsFlag(cli.AllFlag, "a", "filter all branches and tags")
ap.SupportsFlag(continueFlag, "c", "log a warning and continue if any errors occur executing statements")
ap.SupportsString(QueryFlag, "q", "queries", "Queries to run, separated by semicolons. If not provided, queries are read from STDIN.")
Expand Down Expand Up @@ -118,67 +120,133 @@ func (cmd FilterBranchCmd) Exec(ctx context.Context, commandStr string, args []s
queryString = string(queryStringBytes)
}

replay := func(ctx context.Context, commit, _, _ *doltdb.Commit) (doltdb.RootValue, error) {
var cmHash, before hash.Hash
var root doltdb.RootValue
if verbose {
var err error
cmHash, err = commit.HashOf()
if err != nil {
return nil, err
}
nerf, err := getNerf(ctx, dEnv, apr)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}

cli.Printf("processing commit %s\n", cmHash.String())
rootReplayer := &workingSetReplayer{
dEnv: dEnv,
queryString: queryString,
verbose: verbose,
continueOnErr: continueOnErr,
}

root, err = commit.GetRootValue(ctx)
if err != nil {
return nil, err
}
before, err = root.HashOf()
if err != nil {
return nil, err
}
}
commitReplayer := &commitReplayer{
dEnv: dEnv,
queryString: queryString,
verbose: verbose,
continueOnErr: continueOnErr,
}

updatedRoot, err := processFilterQuery(ctx, dEnv, commit, queryString, verbose, continueOnErr)
applyUncommitted := apr.Contains(uncommittedFlag)
switch {
case apr.Contains(branchesFlag):
err = rebase.AllBranches(ctx, dEnv, applyUncommitted, commitReplayer, rootReplayer, nerf)
case apr.Contains(cli.AllFlag):
err = rebase.AllBranchesAndTags(ctx, dEnv, applyUncommitted, commitReplayer, rootReplayer, nerf)
default:
err = rebase.CurrentBranch(ctx, dEnv, applyUncommitted, commitReplayer, rootReplayer, nerf)
}
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}

return 0
}

// workingSetReplayer replays working set root values, rebasing them with a specific query, and returns the updated root value
type workingSetReplayer struct {
dEnv *env.DoltEnv
queryString string
verbose bool
continueOnErr bool
}

var _ rebase.RootReplayer = &workingSetReplayer{}

// ReplayRoot implements the RootReplayer interface
func (r *workingSetReplayer) ReplayRoot(ctx context.Context, root, _, _ doltdb.RootValue) (doltdb.RootValue, error) {
rootHash, err := root.HashOf()
if err != nil {
return nil, err
}
rootHashStr := rootHash.String()
if r.verbose {
cli.Printf("processing commit %s\n", rootHashStr)
}

updatedRoot, err := processFilterQuery(ctx, r.dEnv, root, rootHashStr, r.queryString, r.verbose, r.continueOnErr)
if err != nil {
return nil, err
}

if r.verbose {
var before, after hash.Hash
before, err = root.HashOf()
if err != nil {
return nil, err
}

if verbose {
after, err := updatedRoot.HashOf()
if err != nil {
return nil, err
}
if before != after {
cli.Printf("updated commit %s (root: %s -> %s)\n",
cmHash.String(), before.String(), after.String())
} else {
cli.Printf("no changes to commit %s", cmHash.String())
}
after, err = updatedRoot.HashOf()
if err != nil {
return nil, err
}
if before != after {
cli.Printf("updated commit %s (root: %s -> %s)\n", rootHashStr, before.String(), after.String())
} else {
cli.Printf("no changes to commit %s", rootHashStr)
}
return updatedRoot, nil
}
return updatedRoot, nil
}

nerf, err := getNerf(ctx, dEnv, apr)
// commitReplayer replays a specific commits, rebasing it with a specific query, and returns the updated root value
type commitReplayer struct {
dEnv *env.DoltEnv
queryString string
verbose bool
continueOnErr bool
}

var _ rebase.CommitReplayer = &commitReplayer{}

// ReplayCommit implements the CommitReplayer interface
func (c *commitReplayer) ReplayCommit(ctx context.Context, commit, _, _ *doltdb.Commit) (doltdb.RootValue, error) {
root, err := commit.GetRootValue(ctx)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
return nil, err
}

switch {
case apr.Contains(branchesFlag):
err = rebase.AllBranches(ctx, dEnv, replay, nerf)
case apr.Contains(cli.AllFlag):
err = rebase.AllBranchesAndTags(ctx, dEnv, replay, nerf)
default:
err = rebase.CurrentBranch(ctx, dEnv, replay, nerf)
cmHash, err := commit.HashOf()
if err != nil {
return nil, err
}
cmHashStr := cmHash.String()
if c.verbose {
cli.Printf("processing commit %s\n", cmHashStr)
}

updatedRoot, err := processFilterQuery(ctx, c.dEnv, root, cmHashStr, c.queryString, c.verbose, c.continueOnErr)
if err != nil {
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
return nil, err
}

return 0
if c.verbose {
var before, after hash.Hash
before, err = root.HashOf()
if err != nil {
return nil, err
}
after, err = updatedRoot.HashOf()
if err != nil {
return nil, err
}
if before != after {
cli.Printf("updated commit %s (root: %s -> %s)\n", cmHashStr, before.String(), after.String())
} else {
cli.Printf("no changes to commit %s", cmHashStr)
}
}
return updatedRoot, nil
}

func getNerf(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults) (rebase.NeedsRebaseFn, error) {
Expand Down Expand Up @@ -208,8 +276,8 @@ func getNerf(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResu
return rebase.StopAtCommit(cm), nil
}

func processFilterQuery(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commit, query string, verbose bool, continueOnErr bool) (doltdb.RootValue, error) {
sqlCtx, eng, err := rebaseSqlEngine(ctx, dEnv, cm)
func processFilterQuery(ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootValue, cmHashStr string, query string, verbose bool, continueOnErr bool) (doltdb.RootValue, error) {
sqlCtx, eng, err := rebaseSqlEngine(ctx, dEnv, root)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -246,11 +314,7 @@ func processFilterQuery(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commi
if err != nil {
if continueOnErr {
if verbose {
cmHash, cmErr := cm.HashOf()
if cmErr != nil {
return nil, err
}
cli.PrintErrf("error encountered processing commit %s (continuing): %s\n", cmHash.String(), err.Error())
cli.PrintErrf("error encountered processing commit %s (continuing): %s\n", cmHashStr, err.Error())
}
} else {
return nil, err
Expand All @@ -271,7 +335,7 @@ func processFilterQuery(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commi
// The SQL engine returned has transactions disabled. This is to prevent transactions starts from overwriting the root
// we set manually with the one at the working set of the HEAD being rebased.
// Some functionality will not work on this kind of engine, e.g. many DOLT_ functions.
func rebaseSqlEngine(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commit) (*sql.Context, *engine.SqlEngine, error) {
func rebaseSqlEngine(ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootValue) (*sql.Context, *engine.SqlEngine, error) {
tmpDir, err := dEnv.TempTableFilesDir()
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -309,11 +373,6 @@ func rebaseSqlEngine(ctx context.Context, dEnv *env.DoltEnv, cm *doltdb.Commit)
parallelism := runtime.GOMAXPROCS(0)
azr := analyzer.NewBuilder(pro).WithParallelism(parallelism).Build()

root, err := cm.GetRootValue(ctx)
if err != nil {
return nil, nil, err
}

err = db.SetRoot(sqlCtx, root)
if err != nil {
return nil, nil, err
Expand Down
2 changes: 1 addition & 1 deletion go/libraries/doltcore/doltdb/doltdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -1231,7 +1231,7 @@ func (ddb *DoltDB) NewBranchAtCommit(ctx context.Context, branchRef ref.DoltRef,
var ws *WorkingSet
var currWsHash hash.Hash
ws, err = ddb.ResolveWorkingSet(ctx, wsRef)
if err == ErrWorkingSetNotFound {
if errors.Is(err, ErrWorkingSetNotFound) {
ws = EmptyWorkingSet(wsRef)
} else if err != nil {
return err
Expand Down
Loading

0 comments on commit 8c9cfe7

Please sign in to comment.