Skip to content

Commit

Permalink
Added anyhow
Browse files Browse the repository at this point in the history
Signed-off-by: Ketan Umare <[email protected]>
  • Loading branch information
kumare3 committed May 10, 2024
1 parent 0027fe1 commit bdf27de
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 28 deletions.
45 changes: 31 additions & 14 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import math
import os # TODO: use flytekit logger
import typing
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set, Union, cast

Expand All @@ -13,6 +14,8 @@
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.task import ReferenceTask
from flytekit.core.type_engine import TypeEngine
from flytekit.core.utils import timeit
from flytekit.exceptions import scopes as exception_scopes
from flytekit.loggers import logger
Expand All @@ -25,14 +28,14 @@

class ArrayNodeMapTask(PythonTask):
def __init__(
self,
# TODO: add support for other Flyte entities
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: Optional[float] = None,
bound_inputs: Optional[Set[str]] = None,
**kwargs,
self,
# TODO: add support for other Flyte entities
python_function_task: Union[PythonFunctionTask, PythonInstanceTask, functools.partial],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: Optional[float] = None,
bound_inputs: Optional[Set[str]] = None,
**kwargs,
):
"""
:param python_function_task: The task to be executed in parallel
Expand All @@ -55,7 +58,21 @@ def __init__(

# TODO: add support for other Flyte entities
if not (isinstance(actual_task, PythonFunctionTask) or isinstance(actual_task, PythonInstanceTask)):
raise ValueError("Only PythonFunctionTask and PythonInstanceTask are supported in map tasks.")
from flytekit.remote import FlyteTask
if isinstance(actual_task, FlyteTask):
# TODO This hack has to be done for remote tasks
TypeEngine.guess_python_types(actual_task.interface)
collection_interface = transform_interface_to_list_interface(
actual_task.interface, bound_inputs, False
)
super().__init__(name=f"{actual_task.name}-arrnode", raw_interface=actual_task.interface,
task_type=actual_task.type, task_config=None, task_type_version=1, **kwargs)
return
raise ValueError("Only PythonFunctionTask | PythonInstanceTask | FlyteTask (remote) are supported in "
"map tasks.")
if isinstance(python_function_task, ReferenceTask):
raise AssertionError(
"ReferenceTasks cannot be used in map tasks. Use PythonFunctionTask OR flyteremote.fetch instead.")

n_outputs = len(actual_task.python_interface.outputs)
if n_outputs > 1:
Expand Down Expand Up @@ -313,11 +330,11 @@ def _raw_execute(self, **kwargs) -> Any:


def map_task(
task_function: PythonFunctionTask,
concurrency: Optional[int] = None,
# TODO why no min_successes?
min_success_ratio: float = 1.0,
**kwargs,
task_function: typing.Union[PythonFunctionTask, PythonInstanceTask, functools.partial, "FlyteTask"],
concurrency: Optional[int] = None,
# TODO why no min_successes?
min_success_ratio: float = 1.0,
**kwargs,
):
"""Map task that uses the ``ArrayNode`` construct..
Expand Down
6 changes: 5 additions & 1 deletion flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,10 +480,14 @@ def __init__(
disable_deck (bool): (deprecated) If true, this task will not output deck html file
enable_deck (bool): If true, this task will output deck html file
"""
if "raw_interface" in kwargs:
raw_interface = kwargs.pop("raw_interface")
else:
raw_interface = transform_interface_to_typed_interface(interface, allow_partial_artifact_id_binding=True)
super().__init__(
task_type=task_type,
name=name,
interface=transform_interface_to_typed_interface(interface, allow_partial_artifact_id_binding=True),
interface=raw_interface,
**kwargs,
)
self._python_interface = interface if interface else Interface()
Expand Down
6 changes: 4 additions & 2 deletions rust/flyrs/src/distribution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ use object_store::{ObjectStore, parse_url};
use object_store::path::Path;
use tar::Archive;
use url::Url;
use anyhow::{bail, Result};


#[tracing::instrument(err)]
pub async fn download_unarchive_distribution(src: &Url, dst: &String) -> Result<(), Box<dyn std::error::Error>> {
pub async fn download_unarchive_distribution(src: &Url, dst: &String) -> Result<()> {
// Uses the object_store crate to download the distribution from the source to the destination path and untar and unzip it
// The source is a URL to the distribution
// The destination path is the path to the directory where the distribution will be downloaded and extracted
Expand All @@ -23,4 +25,4 @@ pub async fn download_unarchive_distribution(src: &Url, dst: &String) -> Result<
let mut archive = Archive::new(tar_data);
archive.unpack(dst)?;
Ok(())
}
}
17 changes: 7 additions & 10 deletions rust/flyrs/src/executor.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt::{Display, Formatter};
use anyhow::{bail, Result};

use clap::Parser;
use pyo3::prelude::*;
Expand Down Expand Up @@ -35,10 +36,7 @@ pub struct ExecutorArgs {

impl Display for ExecutorArgs {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "ExecutorArgs {{ inputs: {}, output_prefix: {}, test: {}, raw_output_data_prefix: {}, resolver: {}, resolver_args: {:?}, checkpoint_path: {:?}, prev_checkpoint: {:?}, dynamic_addl_distro: {:?}, dynamic_dest_dir: {:?} }}",
self.inputs, self.output_prefix, self.test, self.raw_output_data_prefix,
self.resolver, self.resolver_args, self.checkpoint_path, self.prev_checkpoint,
self.dynamic_addl_distro, self.dynamic_dest_dir)
write!(f, "{:?}", self)
}
}

Expand Down Expand Up @@ -69,9 +67,9 @@ fn debug_python_setup(py: Python) {
}

#[tracing::instrument(err)]
pub async fn execute_task(args: &ExecutorArgs) -> Result<(), Box<dyn std::error::Error>>{
pub async fn execute_task(args: &ExecutorArgs) -> Result<()> {
pyo3::prepare_freethreaded_python();
let _ = Python::with_gil(|py| -> Result<(), Box<dyn std::error::Error>> {
let _ = Python::with_gil(|py| -> Result<()> {
debug_python_setup(py);
let entrypoint = PyModule::import_bound(py, "flytekit.bin.entrypoint").unwrap();

Expand All @@ -98,8 +96,7 @@ pub async fn execute_task(args: &ExecutorArgs) -> Result<(), Box<dyn std::error:
let result = entrypoint.call_method1("_execute_task", args).unwrap();

if !result.is_none() {
debug!("Task failed");
return Err("Task failed".into());
bail!("Task failed");
}
debug!("Task completed");
Ok(())
Expand All @@ -109,11 +106,11 @@ pub async fn execute_task(args: &ExecutorArgs) -> Result<(), Box<dyn std::error:
}

#[tracing::instrument(level = Level::DEBUG, err)]
pub async fn run(executor_args: &ExecutorArgs) -> Result<(), Box<dyn std::error::Error>> {
pub async fn run(executor_args: &ExecutorArgs) -> Result<()> {
if executor_args.dynamic_addl_distro.is_some() {
info!("Found Dynamic distro {:?}", executor_args.dynamic_addl_distro);
if executor_args.dynamic_dest_dir.is_none() {
return Err("Dynamic distro requires a destination directory".into());
bail!("Dynamic distro requires a destination directory");
}
let src_url = url::Url::parse(executor_args.dynamic_addl_distro.clone().unwrap().as_str())?;
download_unarchive_distribution(&src_url, &executor_args.dynamic_dest_dir.clone().unwrap()).await?;
Expand Down
10 changes: 9 additions & 1 deletion rust/flyrs/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
use anyhow::{bail, Result};
use clap::Parser;
<<<<<<< Updated upstream
use tracing::info;
use tokio;
use tokio::runtime::Builder;
use tracing_subscriber;
=======
use env_logger::Env;
use log::info;
use tokio;
use tokio::runtime::Builder;
>>>>>>> Stashed changes

mod executor;
mod distribution;


fn main() -> Result<(), Box<dyn std::error::Error>> {
fn main() -> Result<()> {

tracing_subscriber::fmt::init();

Expand Down

0 comments on commit bdf27de

Please sign in to comment.