diff --git a/package.json b/package.json index 564df17..537cc23 100644 --- a/package.json +++ b/package.json @@ -90,6 +90,12 @@ "type": "string" }, "type": "array" + }, + "tach.configuration": { + "default": "", + "description": "Path to a `tach.toml` file to use for configuration. By default, the extension will mirror the behavior that the `tach` CLI would have.", + "scope": "resource", + "type": "string" } } }, diff --git a/src/common/server.ts b/src/common/server.ts index e42e511..20e1d19 100644 --- a/src/common/server.ts +++ b/src/common/server.ts @@ -14,9 +14,30 @@ import { traceError, traceInfo, traceVerbose } from './log/logging'; import { getExtensionSettings, getGlobalSettings, getWorkspaceSettings, ISettings } from './settings'; import { getLSClientTraceLevel, getProjectRoot } from './utilities'; import { isVirtualWorkspace } from './vscodeapi'; +import { execFile } from 'child_process'; +import { supportsCustomConfig, VersionInfo } from './version'; export type IInitOptions = { settings: ISettings[]; globalSettings: ISettings }; +function executeCommand(file: string, args: string[] = []): Promise { + return new Promise((resolve, reject) => { + execFile(file, args, (error, stdout, stderr) => { + if (error) { + reject(new Error(stderr || error.message)); + } else { + resolve(stdout); + } + }); + }); +} + +async function getTachVersion(pythonExecutable: string): Promise { + const stdout = await executeCommand(pythonExecutable, ["-m", "tach", "--version"]); + const version = stdout.trim().split(" ")[1]; + const [major, minor, patch] = version.split(".").map((x) => parseInt(x, 10)); + return new VersionInfo(major, minor, patch); + } + async function createServer( settings: ISettings, serverId: string, @@ -34,7 +55,19 @@ async function createServer( newEnv.PYTHONPATH = BUNDLED_PYTHON_LIBS_DIR; } + const args = settings.interpreter.slice(1).concat(["-m", "tach", "server"]); + + if (settings.configuration) { + const version = await getTachVersion(command); + if (!supportsCustomConfig(version)) { + traceError(`Server: Tach version ${version.toString()} does not support custom configuration files.`); + } else { + traceInfo(`Server: Using custom configuration file: ${settings.configuration}`); + args.push("-c", settings.configuration); + } + } + traceInfo(`Server run command: ${[command, ...args].join(' ')}`); const serverOptions: ServerOptions = { diff --git a/src/common/settings.ts b/src/common/settings.ts index c46b309..cf4121c 100644 --- a/src/common/settings.ts +++ b/src/common/settings.ts @@ -12,6 +12,7 @@ export interface ISettings { workspace: string; interpreter: string[]; importStrategy: ImportStrategy; + configuration: string | null; } export function getExtensionSettings(namespace: string, includeInterpreter?: boolean): Promise { @@ -65,6 +66,7 @@ export async function getWorkspaceSettings( workspace: workspace.uri.toString(), interpreter: resolveVariables(interpreter, workspace), importStrategy: config.get(`importStrategy`) ?? 'fromEnvironment', + configuration: config.get(`configuration`) ?? null, }; return workspaceSetting; } @@ -90,6 +92,7 @@ export async function getGlobalSettings(namespace: string, includeInterpreter?: workspace: process.cwd(), interpreter: interpreter, importStrategy: getGlobalValue(config, 'importStrategy', 'useBundled'), + configuration: getGlobalValue(config, 'configuration', null), }; return setting; } diff --git a/src/common/version.ts b/src/common/version.ts new file mode 100644 index 0000000..bd47c50 --- /dev/null +++ b/src/common/version.ts @@ -0,0 +1,33 @@ +export class VersionInfo { + constructor( + public major: number, + public minor: number, + public patch: number + ) {} + + toString(): string { + return `${this.major}.${this.minor}.${this.patch}`; + } +} + +function versionGte(a: VersionInfo, b: VersionInfo): boolean { + if (a.major !== b.major) { + return a.major > b.major; + } + if (a.minor !== b.minor) { + return a.minor > b.minor; + } + return a.patch >= b.patch; +} + +const MIN_VERSION_WITH_CONFIG = { + major: 0, + minor: 24, + patch: 0, +}; + + + +export function supportsCustomConfig(version: VersionInfo): boolean { + return versionGte(version, MIN_VERSION_WITH_CONFIG); +}