Skip to content

Commit

Permalink
fix(taskprocessing): select preferred provider when running sync task…
Browse files Browse the repository at this point in the history
…, fix task type values according to preferred provider

Signed-off-by: Julien Veyssier <[email protected]>
  • Loading branch information
julien-nc committed Aug 12, 2024
1 parent b34edf2 commit dbab2a8
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
18 changes: 9 additions & 9 deletions lib/private/TaskProcessing/Manager.php
Original file line number Diff line number Diff line change
Expand Up @@ -649,39 +649,39 @@ public function getProviders(): array {
return $this->providers;
}

public function getPreferredProvider(string $taskType) {
public function getPreferredProvider(string $taskTypeId) {
try {
$preferences = json_decode($this->config->getAppValue('core', 'ai.taskprocessing_provider_preferences', 'null'), associative: true, flags: JSON_THROW_ON_ERROR);
$providers = $this->getProviders();
if (isset($preferences[$taskType])) {
$provider = current(array_values(array_filter($providers, fn ($provider) => $provider->getId() === $preferences[$taskType])));
if (isset($preferences[$taskTypeId])) {
$provider = current(array_values(array_filter($providers, fn ($provider) => $provider->getId() === $preferences[$taskTypeId])));
if ($provider !== false) {
return $provider;
}
}
// By default, use the first available provider
foreach ($providers as $provider) {
if ($provider->getTaskTypeId() === $taskType) {
if ($provider->getTaskTypeId() === $taskTypeId) {
return $provider;
}
}
} catch (\JsonException $e) {
$this->logger->warning('Failed to parse provider preferences while getting preferred provider for task type ' . $taskType, ['exception' => $e]);
$this->logger->warning('Failed to parse provider preferences while getting preferred provider for task type ' . $taskTypeId, ['exception' => $e]);
}
throw new \OCP\TaskProcessing\Exception\Exception('No matching provider found');
}

public function getAvailableTaskTypes(): array {
if ($this->availableTaskTypes === null) {
$taskTypes = $this->_getTaskTypes();
$providers = $this->getProviders();

$availableTaskTypes = [];
foreach ($providers as $provider) {
if (!isset($taskTypes[$provider->getTaskTypeId()])) {
foreach ($taskTypes as $taskType) {
try {
$provider = $this->getPreferredProvider($taskType->getId());
} catch (\OCP\TaskProcessing\Exception\Exception $e) {
continue;
}
$taskType = $taskTypes[$provider->getTaskTypeId()];
try {
$availableTaskTypes[$provider->getTaskTypeId()] = [
'name' => $taskType->getName(),
Expand Down
9 changes: 7 additions & 2 deletions lib/private/TaskProcessing/SynchronousBackgroundJob.php
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,14 @@ protected function run($argument) {
if (!$provider instanceof ISynchronousProvider) {
continue;
}
$taskType = $provider->getTaskTypeId();
$taskTypeId = $provider->getTaskTypeId();
// only use this provider if it is the preferred one
$preferredProvider = $this->taskProcessingManager->getPreferredProvider($taskTypeId);
if ($provider->getId() !== $preferredProvider->getId()) {
continue;
}
try {
$task = $this->taskProcessingManager->getNextScheduledTask([$taskType]);
$task = $this->taskProcessingManager->getNextScheduledTask([$taskTypeId]);
} catch (NotFoundException $e) {
continue;
} catch (Exception $e) {
Expand Down
4 changes: 2 additions & 2 deletions lib/public/TaskProcessing/IManager.php
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ public function hasProviders(): bool;
public function getProviders(): array;

/**
* @param string $taskType
* @param string $taskTypeId
* @return IProvider
* @throws Exception
* @since 30.0.0
*/
public function getPreferredProvider(string $taskType);
public function getPreferredProvider(string $taskTypeId);

/**
* @return array<array-key,array{name: string, description: string, inputShape: ShapeDescriptor[], inputShapeEnumValues: ShapeEnumValue[][], inputShapeDefaults: array<array-key, numeric|string>, optionalInputShape: ShapeDescriptor[], optionalInputShapeEnumValues: ShapeEnumValue[][], optionalInputShapeDefaults: array<array-key, numeric|string>, outputShape: ShapeDescriptor[], outputShapeEnumValues: ShapeEnumValue[][], optionalOutputShape: ShapeDescriptor[], optionalOutputShapeEnumValues: ShapeEnumValue[][]}>
Expand Down

0 comments on commit dbab2a8

Please sign in to comment.