diff --git a/src/extension.ts b/src/extension.ts index 69f822d..238969d 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -32,7 +32,7 @@ export function activate(context: vscode.ExtensionContext) { var disposable_createQueryTemplate = vscode.commands.registerCommand('dsk.createQueryTemplate', createQueryTemplate); context.subscriptions.push(disposable_createQueryTemplate); - var newQueryOption = async () => { + var newQueryOption = async (profile?: azdata.IConnectionProfile, context?: azdata.ObjectExplorerContext) => { let scriptText:string = ""; const workbenchConfig = vscode.workspace.getConfiguration('newquerytemplate'); let queryTemplateArray = new Array(); @@ -44,11 +44,10 @@ export function activate(context: vscode.ExtensionContext) { `; }); } - - new placeScript().placescript(scriptText); + new placeScript().placescript(scriptText, profile, context); }; vscode.commands.registerCommand('dsk.newqueryoption', newQueryOption); - azdata.tasks.registerTask('dsk.newqueryoption', newQueryOption); + azdata.tasks.registerTask('dsk.newqueryoption', (profile?: azdata.IConnectionProfile, context?: azdata.ObjectExplorerContext) => newQueryOption(profile, context)); var useDatabaseCmd = () => { diff --git a/src/placescript.ts b/src/placescript.ts index d1c7cb3..eba06c2 100644 --- a/src/placescript.ts +++ b/src/placescript.ts @@ -4,18 +4,46 @@ import * as sqlops from 'azdata'; import * as vscode from 'vscode'; export class placeScript { + + private connectionId: string = ''; + private dbName: string = ''; + // places scriptText into fileName editor with current connection - public async placescript(scriptText) { + public async placescript(scriptText:string, context?:sqlops.IConnectionProfile, oecontext?: sqlops.ObjectExplorerContext) { try { - let connection = await sqlops.connection.getCurrentConnection(); + var connection; + vscode.window.showInformationMessage('starting placescript'); + if (context && context.id) { + this.connectionId = context.id; + this.dbName = context.databaseName; + } else if (oecontext) { + connection = oecontext.connectionProfile; + this.connectionId = connection.id; + } else { + connection = await sqlops.connection.getCurrentConnection(); + if (connection) { + this.connectionId = connection.connectionId; + this.dbName = connection.databaseName; + } + } let doc = await vscode.workspace.openTextDocument({language: 'sql'}); let editor = await vscode.window.showTextDocument(doc, 1, false); editor.edit(edit => { edit.insert(new vscode.Position(0, 0), scriptText); }); - - if (connection.connectionId) { - await sqlops.queryeditor.connect(doc.uri.toString(), connection.connectionId); + if ((context || connection) && this.connectionId) { + if (this.dbName !== '') { + var providerName:string; + if (context) { + providerName = context.providerName; + } else if (connection) { + providerName = connection.providerId; + } + let dProvider = await sqlops.dataprotocol.getProvider(providerName, sqlops.DataProviderType.ConnectionProvider); + let connectionUri = await sqlops.connection.getUriForConnection(this.connectionId); + await dProvider.changeDatabase(connectionUri,this.dbName); + } + await sqlops.queryeditor.connect(doc.uri.toString(), this.connectionId); } } catch (err) { vscode.window.showErrorMessage(err); @@ -36,7 +64,7 @@ export class placeScript { editor.edit(edit => { edit.insert(new vscode.Position(0, 0), scriptText); }); - if (connection.connectionId) { + if (connection.connectionId) { await sqlops.queryeditor.connect(doc.uri.toString(), connection.connectionId); } } catch (err) {