From 89ab6a162c8ca4143e347232fd80748fcc5acc83 Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Fri, 4 Aug 2023 16:27:57 -0400 Subject: [PATCH 1/3] running black across codebase --- .../parlai_test_script.py | 7 +- examples/remote_procedure/mnist/model.py | 4 +- .../simple_static_task/examine_results.py | 4 +- .../static_test_prolific_script.py | 10 +- hydra_plugins/mephisto_path_plugin.py | 4 +- .../_subcomponents/agent_state.py | 4 +- .../_subcomponents/task_runner.py | 20 +- .../architects/channels/websocket_channel.py | 16 +- .../architects/ec2/cleanup_ec2_server_all.py | 34 ++-- .../ec2/cleanup_ec2_server_by_name.py | 3 +- .../architects/ec2/ec2_architect.py | 20 +- .../architects/ec2/ec2_helpers.py | 42 ++-- .../architects/ec2/prepare_ec2_servers.py | 8 +- .../architects/heroku_architect.py | 52 ++--- .../architects/local_architect.py | 4 +- .../abstractions/architects/mock_architect.py | 4 +- .../architects/router/build_router.py | 7 +- .../router/flask/mephisto_flask_blueprint.py | 8 +- mephisto/abstractions/blueprint.py | 44 +--- .../abstract/static_task/static_blueprint.py | 16 +- .../static_task/static_task_runner.py | 4 +- .../blueprints/mixins/onboarding_required.py | 16 +- .../blueprints/mixins/screen_task_required.py | 13 +- .../blueprints/mixins/use_gold_unit.py | 48 +---- .../blueprints/mock/mock_blueprint.py | 8 +- .../blueprints/mock/mock_task_runner.py | 8 +- .../parlai_chat/parlai_chat_blueprint.py | 39 +--- .../parlai_chat/parlai_chat_task_builder.py | 4 +- .../parlai_chat/parlai_chat_task_runner.py | 13 +- .../remote_procedure_blueprint.py | 24 +-- .../remote_procedure_task_runner.py | 12 +- .../static_html_task/static_html_blueprint.py | 22 +- .../static_react_blueprint.py | 8 +- mephisto/abstractions/crowd_provider.py | 4 +- mephisto/abstractions/database.py | 63 ++---- .../abstractions/databases/local_database.py | 91 +++------ .../providers/mock/mock_provider.py | 4 +- .../providers/mock/mock_requester.py | 4 +- .../abstractions/providers/mock/mock_unit.py | 13 +- .../providers/mturk/mturk_agent.py | 16 +- .../providers/mturk/mturk_datastore.py | 16 +- .../providers/mturk/mturk_provider.py | 19 +- .../providers/mturk/mturk_requester.py | 8 +- .../providers/mturk/mturk_unit.py | 28 +-- .../providers/mturk/mturk_utils.py | 85 ++------ .../providers/mturk/mturk_worker.py | 56 ++---- .../providers/mturk/utils/script_utils.py | 4 +- .../mturk_sandbox/sandbox_mturk_agent.py | 8 +- .../mturk_sandbox/sandbox_mturk_requester.py | 8 +- .../mturk_sandbox/sandbox_mturk_unit.py | 8 +- .../mturk_sandbox/sandbox_mturk_worker.py | 8 +- .../prolific/api/base_api_resource.py | 17 +- .../providers/prolific/api/bonuses.py | 4 +- .../providers/prolific/api/constants.py | 1 + .../prolific/api/data_models/base_model.py | 12 +- .../api/data_models/bonus_payments.py | 9 +- .../data_models/eligibility_requirement.py | 69 +++---- .../prolific/api/data_models/message.py | 25 +-- .../prolific/api/data_models/participant.py | 12 +- .../api/data_models/participant_group.py | 39 ++-- .../prolific/api/data_models/project.py | 43 ++-- .../prolific/api/data_models/study.py | 189 +++++++++--------- .../prolific/api/data_models/submission.py | 56 +++--- .../prolific/api/data_models/user.py | 135 ++++++------- .../prolific/api/data_models/workspace.py | 53 ++--- .../api/data_models/workspace_balance.py | 37 ++-- .../age_range_eligibility_requirement.py | 9 +- ...pproval_numbers_eligibility_requirement.py | 6 +- .../approval_rate_eligibility_requirement.py | 5 +- .../base_eligibility_requirement.py | 15 +- ...stom_black_list_eligibility_requirement.py | 5 +- ...stom_white_list_eligibility_requirement.py | 5 +- .../joined_before_eligibility_requirement.py | 5 +- ...rticipant_group_eligibility_requirement.py | 15 +- .../prolific/api/eligibility_requirements.py | 8 +- .../providers/prolific/api/exceptions.py | 11 +- .../providers/prolific/api/messages.py | 16 +- .../prolific/api/participant_groups.py | 30 +-- .../providers/prolific/api/projects.py | 9 +- .../providers/prolific/api/studies.py | 26 +-- .../providers/prolific/api/submissions.py | 14 +- .../providers/prolific/api/users.py | 2 +- .../providers/prolific/prolific_agent.py | 28 +-- .../providers/prolific/prolific_datastore.py | 54 +++-- .../providers/prolific/prolific_provider.py | 87 ++++---- .../providers/prolific/prolific_requester.py | 16 +- .../providers/prolific/prolific_unit.py | 70 +++---- .../providers/prolific/prolific_utils.py | 66 +++--- .../providers/prolific/prolific_worker.py | 117 ++++++----- .../abstractions/test/architect_tester.py | 8 +- .../abstractions/test/blueprint_tester.py | 8 +- .../test/crowd_provider_tester.py | 8 +- .../test/data_model_database_tester.py | 16 +- mephisto/client/api.py | 19 +- mephisto/client/cli.py | 39 +--- mephisto/client/cli_commands.py | 40 +--- mephisto/client/full/server.py | 12 +- mephisto/client/review/review_server.py | 39 +--- mephisto/data_model/agent.py | 62 ++---- mephisto/data_model/assignment.py | 19 +- mephisto/data_model/requester.py | 8 +- mephisto/data_model/task_run.py | 31 +-- mephisto/data_model/unit.py | 8 +- mephisto/data_model/worker.py | 24 +-- mephisto/operations/client_io_handler.py | 38 ++-- mephisto/operations/config_handler.py | 4 +- mephisto/operations/datatypes.py | 4 +- mephisto/operations/operator.py | 42 +--- mephisto/operations/registry.py | 28 +-- mephisto/operations/task_launcher.py | 13 +- mephisto/operations/worker_pool.py | 113 +++-------- .../gh_actions/auto_generate_blueprint.py | 5 +- .../gh_actions/auto_generate_provider.py | 4 +- .../local_db/load_data_to_mephisto_db.py | 4 +- .../scripts/local_db/remove_accepted_tip.py | 12 +- .../local_db/review_feedback_for_task.py | 4 +- .../scripts/local_db/review_tips_for_task.py | 16 +- mephisto/scripts/mturk/cleanup.py | 11 +- .../scripts/mturk/identify_broken_units.py | 24 +-- mephisto/scripts/mturk/launch_makeup_hits.py | 10 +- .../mturk/print_outstanding_hit_status.py | 4 +- .../mturk/soft_block_workers_by_mturk_id.py | 4 +- mephisto/tools/data_browser.py | 16 +- mephisto/tools/examine_utils.py | 16 +- mephisto/tools/scripts.py | 19 +- mephisto/utils/metrics.py | 12 +- mephisto/utils/qualifications.py | 8 +- mephisto/utils/testing.py | 4 +- scripts/check_npm_package_versions.py | 4 +- scripts/sync_mephisto_task.py | 4 +- .../architects/test_local_architect.py | 9 +- .../blueprints/test_mixin_core.py | 52 ++--- .../blueprints/test_mock_blueprint.py | 4 +- .../mturk_sandbox/test_mturk_provider.py | 36 +--- .../providers/prolific/test_prolific_utils.py | 95 +++++---- test/core/test_live_runs.py | 119 +++-------- test/core/test_operator.py | 31 +-- test/core/test_task_launcher.py | 8 +- test/test_data_model.py | 4 +- test/tools/test_data_brower.py | 4 +- test/utils/prolific_api/test_data_models.py | 32 +-- 141 files changed, 1279 insertions(+), 2172 deletions(-) diff --git a/examples/parlai_chat_task_demo/parlai_test_script.py b/examples/parlai_chat_task_demo/parlai_test_script.py index b7b09eca7..c08c075d6 100644 --- a/examples/parlai_chat_task_demo/parlai_test_script.py +++ b/examples/parlai_chat_task_demo/parlai_test_script.py @@ -27,8 +27,7 @@ class ParlAITaskConfig(build_default_task_config("example")): # type: ignore turn_timeout: int = field( default=300, metadata={ - "help": "Maximum response time before kicking " - "a worker out, default 300 seconds" + "help": "Maximum response time before kicking " "a worker out, default 300 seconds" }, ) @@ -46,9 +45,7 @@ def main(operator: "Operator", cfg: DictConfig) -> None: ) world_opt["send_task_data"] = True - shared_state = SharedParlAITaskState( - world_opt=world_opt, onboarding_world_opt=world_opt - ) + shared_state = SharedParlAITaskState(world_opt=world_opt, onboarding_world_opt=world_opt) operator.launch_task_run(cfg.mephisto, shared_state) operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30) diff --git a/examples/remote_procedure/mnist/model.py b/examples/remote_procedure/mnist/model.py index 4edb9ea98..adc4a19ed 100644 --- a/examples/remote_procedure/mnist/model.py +++ b/examples/remote_procedure/mnist/model.py @@ -11,9 +11,7 @@ from collections import OrderedDict import torch.utils.model_zoo as model_zoo -model_urls = { - "mnist": "http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/mnist-b07bb66b.pth" -} +model_urls = {"mnist": "http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/mnist-b07bb66b.pth"} class MLP(nn.Module): diff --git a/examples/simple_static_task/examine_results.py b/examples/simple_static_task/examine_results.py index 39eaaedf9..9614c2902 100644 --- a/examples/simple_static_task/examine_results.py +++ b/examples/simple_static_task/examine_results.py @@ -25,7 +25,9 @@ def format_for_printing_data(data): ) inputs = contents["inputs"] - inputs_string = f"Character: {inputs['character_name']}\nDescription: {inputs['character_description']}\n" + inputs_string = ( + f"Character: {inputs['character_name']}\nDescription: {inputs['character_description']}\n" + ) outputs = contents["outputs"] output_string = f" Rating: {outputs['rating']}\n" diff --git a/examples/simple_static_task/static_test_prolific_script.py b/examples/simple_static_task/static_test_prolific_script.py index 16e414b71..e2ef6ec87 100644 --- a/examples/simple_static_task/static_test_prolific_script.py +++ b/examples/simple_static_task/static_test_prolific_script.py @@ -14,7 +14,7 @@ from mephisto.utils.qualifications import make_qualification_dict -@task_script(default_config_file='prolific_example') +@task_script(default_config_file="prolific_example") def main(operator, cfg: DictConfig) -> None: shared_state = SharedStaticTaskState() @@ -27,9 +27,9 @@ def main(operator, cfg: DictConfig) -> None: # Note that we'll prefix names with a customary `web.eligibility.models.` later in the code shared_state.prolific_specific_qualifications = [ { - 'name': 'AgeRangeEligibilityRequirement', - 'min_age': 18, - 'max_age': 100, + "name": "AgeRangeEligibilityRequirement", + "min_age": 18, + "max_age": 100, }, ] @@ -37,5 +37,5 @@ def main(operator, cfg: DictConfig) -> None: operator.wait_for_runs_then_shutdown(skip_input=True, log_rate=30) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/hydra_plugins/mephisto_path_plugin.py b/hydra_plugins/mephisto_path_plugin.py index 8f2e811d7..722ba0417 100644 --- a/hydra_plugins/mephisto_path_plugin.py +++ b/hydra_plugins/mephisto_path_plugin.py @@ -19,6 +19,4 @@ def manipulate_search_path(self, search_path: ConfigSearchPath) -> None: profile_path_user = os.path.join(DEFAULT_CONFIG_FOLDER, "hydra_configs") search_path.append(provider="mephisto-profiles", path=f"file://{profile_path}") - search_path.append( - provider="mephisto-profiles-user", path=f"file://{profile_path_user}" - ) + search_path.append(provider="mephisto-profiles-user", path=f"file://{profile_path_user}") diff --git a/mephisto/abstractions/_subcomponents/agent_state.py b/mephisto/abstractions/_subcomponents/agent_state.py index ef0f63aef..b7036e70d 100644 --- a/mephisto/abstractions/_subcomponents/agent_state.py +++ b/mephisto/abstractions/_subcomponents/agent_state.py @@ -85,9 +85,7 @@ def __new__(cls, agent: Union["Agent", "OnboardingAgent"]) -> "AgentState": if isinstance(agent, Agent): correct_class = get_blueprint_from_type(agent.task_type).AgentStateClass else: - correct_class = get_blueprint_from_type( - agent.task_type - ).OnboardingAgentStateClass + correct_class = get_blueprint_from_type(agent.task_type).OnboardingAgentStateClass return super().__new__(correct_class) else: # We are constructing another instance directly diff --git a/mephisto/abstractions/_subcomponents/task_runner.py b/mephisto/abstractions/_subcomponents/task_runner.py index db925bbd8..b0c939a78 100644 --- a/mephisto/abstractions/_subcomponents/task_runner.py +++ b/mephisto/abstractions/_subcomponents/task_runner.py @@ -110,9 +110,7 @@ class TaskRunner(ABC): passing agents through a task. """ - def __init__( - self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState" - ): + def __init__(self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState"): self.args = args self.shared_state = shared_state self.task_run = task_run @@ -172,9 +170,7 @@ def _launch_and_run_onboarding( """Supervise the completion of an onboarding""" with ONGOING_THREAD_COUNT.labels( thread_type="onboarding" - ).track_inprogress(), EXECUTION_DURATION_SECONDS.labels( - thread_type="onboarding" - ).time(): + ).track_inprogress(), EXECUTION_DURATION_SECONDS.labels(thread_type="onboarding").time(): live_run = onboarding_agent.get_live_run() onboarding_id = onboarding_agent.get_agent_id() logger.debug(f"Launching onboarding for {onboarding_agent}") @@ -202,9 +198,7 @@ def _launch_and_run_onboarding( if onboarding_agent.get_status() == AgentState.STATUS_WAITING: # The agent completed the onboarding task async def register_then_cleanup(): - await live_run.worker_pool.register_agent_from_onboarding( - onboarding_agent - ) + await live_run.worker_pool.register_agent_from_onboarding(onboarding_agent) await cleanup_after() live_run.loop_wrap.execute_coro(register_then_cleanup()) @@ -267,9 +261,7 @@ def _launch_and_run_unit( """Supervise the completion of a unit thread""" with ONGOING_THREAD_COUNT.labels( thread_type="unit" - ).track_inprogress(), EXECUTION_DURATION_SECONDS.labels( - thread_type="unit" - ).time(): + ).track_inprogress(), EXECUTION_DURATION_SECONDS.labels(thread_type="unit").time(): try: self.run_unit(unit, agent) except ( @@ -348,9 +340,7 @@ def _launch_and_run_assignment( """Supervise the completion of an assignment thread""" with ONGOING_THREAD_COUNT.labels( thread_type="assignment" - ).track_inprogress(), EXECUTION_DURATION_SECONDS.labels( - thread_type="assignment" - ).time(): + ).track_inprogress(), EXECUTION_DURATION_SECONDS.labels(thread_type="assignment").time(): try: self.run_assignment(assignment, agents) except ( diff --git a/mephisto/abstractions/architects/channels/websocket_channel.py b/mephisto/abstractions/architects/channels/websocket_channel.py index 0eac5ffbc..b2ae43fc8 100644 --- a/mephisto/abstractions/architects/channels/websocket_channel.py +++ b/mephisto/abstractions/architects/channels/websocket_channel.py @@ -98,9 +98,7 @@ async def on_error(error): if hasattr(error, "errno"): if error.errno == errno.ECONNREFUSED: # TODO(CLEAN) replace with channel exception - raise Exception( - f"Socket {self.socket_url} refused connection, cancelling" - ) + raise Exception(f"Socket {self.socket_url} refused connection, cancelling") else: logger.info(f"Socket logged error: {error}") @@ -132,9 +130,7 @@ async def run_socket(): # Outer loop allows reconnects while not self._is_closed: try: - async with websockets.connect( - self.socket_url, open_timeout=30 - ) as websocket: + async with websockets.connect(self.socket_url, open_timeout=30) as websocket: # Inner loop recieves messages until closed self.socket = websocket on_socket_open() @@ -160,15 +156,11 @@ async def run_socket(): self.on_catastrophic_disconnect(self.channel_id) return except OSError as e: - logger.error( - f"Unhandled OSError exception in socket {e}, attempting restart" - ) + logger.error(f"Unhandled OSError exception in socket {e}, attempting restart") await asyncio.sleep(0.2) except websockets.exceptions.InvalidStatusCode as e: if self._retries == 0: - raise ConnectionRefusedError( - "Could not connect after retries" - ) from e + raise ConnectionRefusedError("Could not connect after retries") from e curr_retry = MAX_RETRIES - self._retries logger.exception( f"Status code error {repr(e)}, attempting retry {curr_retry}", diff --git a/mephisto/abstractions/architects/ec2/cleanup_ec2_server_all.py b/mephisto/abstractions/architects/ec2/cleanup_ec2_server_all.py index 807f29ff9..2b648ddb7 100644 --- a/mephisto/abstractions/architects/ec2/cleanup_ec2_server_all.py +++ b/mephisto/abstractions/architects/ec2/cleanup_ec2_server_all.py @@ -27,56 +27,50 @@ def main(): all_server_names = [ os.path.splitext(s)[0] for s in os.listdir(DEFAULT_SERVER_DETAIL_LOCATION) - if s.endswith('json') and s not in EXCLUDE_FILES_IN_SERVER_DIR + if s.endswith("json") and s not in EXCLUDE_FILES_IN_SERVER_DIR ] n_names = len(all_server_names) if not n_names: - logger.info('No servers found to clean up') + logger.info("No servers found to clean up") return logger.info(f'Found {n_names} server names: {", ".join(all_server_names)}') - confirm = input( - f'Are you sure you want to remove the {n_names} found servers? [y/N]\n' - f'>> ' - ) - if confirm != 'y': + confirm = input(f"Are you sure you want to remove the {n_names} found servers? [y/N]\n" f">> ") + if confirm != "y": return # Get EC2 user role - iam_role_name = input( - 'Please enter local profile name for IAM role\n' - '>> ' - ) - logger.info(f'Removing {n_names} servers...') + iam_role_name = input("Please enter local profile name for IAM role\n" ">> ") + logger.info(f"Removing {n_names} servers...") # Cleanup local server JSON files, and remove related EC2 infra skipped_names = [] for i, server_name in enumerate(all_server_names): _name = f'"{server_name}"' - logger.info(f'{i+1}/{n_names} Removing {_name}...') + logger.info(f"{i+1}/{n_names} Removing {_name}...") - session = boto3.Session(profile_name=iam_role_name, region_name='us-east-2') + session = boto3.Session(profile_name=iam_role_name, region_name="us-east-2") try: skipped_names.append(_name) ec2_helpers.remove_instance_and_cleanup(session, server_name) - logger.debug(f'...{_name} - successfully removed') + logger.debug(f"...{_name} - successfully removed") skipped_names.remove(_name) except botocore.exceptions.ClientError as e: - logger.warning(f'...{_name} - could not be removed: {e}') + logger.warning(f"...{_name} - could not be removed: {e}") except json.decoder.JSONDecodeError as e: - logger.warning(f'...{_name} - could not read JSON config: {e}') + logger.warning(f"...{_name} - could not read JSON config: {e}") except Exception as e: - logger.warning(f'...{_name} - encountered error: {e}') + logger.warning(f"...{_name} - encountered error: {e}") if skipped_names: logger.info( f'Could not remove {len(skipped_names)}/{n_names} servers: {", ".join(skipped_names)}' ) else: - logger.info(f'Successfully removed {n_names} servers') + logger.info(f"Successfully removed {n_names} servers") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mephisto/abstractions/architects/ec2/cleanup_ec2_server_by_name.py b/mephisto/abstractions/architects/ec2/cleanup_ec2_server_by_name.py index 69d41e2a4..63ee04f5c 100644 --- a/mephisto/abstractions/architects/ec2/cleanup_ec2_server_by_name.py +++ b/mephisto/abstractions/architects/ec2/cleanup_ec2_server_by_name.py @@ -33,8 +33,7 @@ def main(): f"Please enter server name you want to clean up (existing servers: {all_server_names})\n>> " ) assert ( - os.path.join(DEFAULT_SERVER_DETAIL_LOCATION, f"{server_name}.json") - != DEFAULT_FALLBACK_FILE + os.path.join(DEFAULT_SERVER_DETAIL_LOCATION, f"{server_name}.json") != DEFAULT_FALLBACK_FILE ), "This is going to completely delete the fallback server for your EC2 architect." assert server_name in all_server_names, f"{server_name} does not exist" diff --git a/mephisto/abstractions/architects/ec2/ec2_architect.py b/mephisto/abstractions/architects/ec2/ec2_architect.py index 3695b24e4..9b20c87b5 100644 --- a/mephisto/abstractions/architects/ec2/ec2_architect.py +++ b/mephisto/abstractions/architects/ec2/ec2_architect.py @@ -67,9 +67,7 @@ class EC2ArchitectArgs(ArchitectArgs): """Additional arguments for configuring a heroku architect""" _architect_type: str = ARCHITECT_TYPE - instance_type: str = field( - default="t2.micro", metadata={"help": "Instance type to run router"} - ) + instance_type: str = field(default="t2.micro", metadata={"help": "Instance type to run router"}) subdomain: str = field( default="${mephisto.task.task_name}", metadata={"help": "Subdomain name for routing"}, @@ -122,9 +120,7 @@ def __init__( self.build_dir = build_dir_root self.server_detail_path = self._get_detail_path(self.subdomain) - self.session = boto3.Session( - profile_name=self.profile_name, region_name="us-east-2" - ) + self.session = boto3.Session(profile_name=self.profile_name, region_name="us-east-2") self.server_dir: Optional[str] = None self.server_id: Optional[str] = None @@ -235,17 +231,13 @@ def assert_task_args(cls, args: DictConfig, shared_state: "SharedTaskState"): assert key in fallback_details, f"Fallback file missing required key {key}" session = boto3.Session(profile_name=profile_name, region_name="us-east-2") - is_new_rule = ec2_helpers.rule_is_new( - session, subdomain, fallback_details["listener_arn"] - ) + is_new_rule = ec2_helpers.rule_is_new(session, subdomain, fallback_details["listener_arn"]) if args.architect._deploy_type in ["retain", "standard"]: assert ( is_new_rule ), "Rule was not new, existing subdomain found registered to the listener. Check on AWS." else: - assert ( - not is_new_rule - ), "Rule did not exist, Clean up and redeploy a new retain server." + assert not is_new_rule, "Rule did not exist, Clean up and redeploy a new retain server." def __get_build_directory(self) -> str: """ @@ -274,9 +266,7 @@ def __compile_server(self) -> str: setup_path = os.path.join(SCRIPTS_DIRECTORY, self.server_type) setup_dest = os.path.join(server_build_root, "setup") shutil.copytree(setup_path, setup_dest) - possible_node_modules = os.path.join( - server_build_root, "router", "node_modules" - ) + possible_node_modules = os.path.join(server_build_root, "router", "node_modules") if os.path.exists(possible_node_modules): shutil.rmtree(possible_node_modules) return server_dir diff --git a/mephisto/abstractions/architects/ec2/ec2_helpers.py b/mephisto/abstractions/architects/ec2/ec2_helpers.py index bbbbffef5..f4bcbabd4 100644 --- a/mephisto/abstractions/architects/ec2/ec2_helpers.py +++ b/mephisto/abstractions/architects/ec2/ec2_helpers.py @@ -28,9 +28,7 @@ if TYPE_CHECKING: from omegaconf import DictConfig # type: ignore -botoconfig = Config( - region_name="us-east-2", retries={"max_attempts": 10, "mode": "standard"} -) +botoconfig = Config(region_name="us-east-2", retries={"max_attempts": 10, "mode": "standard"}) DEFAULT_AMI_ID = "ami-0f19d220602031aed" AMI_DEFAULT_USER = "ec2-user" @@ -70,9 +68,7 @@ def check_aws_credentials(profile_name: str) -> bool: return False -def setup_ec2_credentials( - profile_name: str, register_args: Optional["DictConfig"] = None -) -> bool: +def setup_ec2_credentials(profile_name: str, register_args: Optional["DictConfig"] = None) -> bool: return setup_aws_credentials(profile_name, register_args) @@ -204,9 +200,7 @@ def get_certificate(session: boto3.Session, domain_name: str) -> Dict[str, str]: details = client.describe_certificate( CertificateArn=certificate_arn, ) - return_data = details["Certificate"]["DomainValidationOptions"][0][ - "ResourceRecord" - ] + return_data = details["Certificate"]["DomainValidationOptions"][0]["ResourceRecord"] return_data["arn"] = certificate_arn return return_data except KeyError: @@ -234,9 +228,9 @@ def register_zone_records( """ # Get details about the load balancer ec2_client = session.client("elbv2") - balancer = ec2_client.describe_load_balancers( - LoadBalancerArns=[load_balancer_arn], - )["LoadBalancers"][0] + balancer = ec2_client.describe_load_balancers(LoadBalancerArns=[load_balancer_arn],)[ + "LoadBalancers" + ][0] load_balancer_dns = balancer["DNSName"] load_balancer_zone = balancer["CanonicalHostedZoneId"] @@ -905,20 +899,14 @@ def try_server_push(subprocess_args: List[str], retries=5, sleep_time=10.0): """ while retries > 0: try: - subprocess.check_call( - subprocess_args, env=dict(os.environ, SSH_AUTH_SOCK="") - ) + subprocess.check_call(subprocess_args, env=dict(os.environ, SSH_AUTH_SOCK="")) return except subprocess.CalledProcessError: retries -= 1 sleep_time *= 1.5 - logger.info( - f"Timed out trying to push to server. Retries remaining: {retries}" - ) + logger.info(f"Timed out trying to push to server. Retries remaining: {retries}") time.sleep(sleep_time) - raise Exception( - "Could not successfully push to the ec2 instance. See log for errors." - ) + raise Exception("Could not successfully push to the ec2 instance. See log for errors.") def deploy_fallback_server( @@ -932,9 +920,7 @@ def deploy_fallback_server( return True if successful """ client = session.client("ec2") - server_host, allocation_id, association_id = get_instance_address( - session, instance_id - ) + server_host, allocation_id, association_id = get_instance_address(session, instance_id) try: keypair_file = os.path.join(DEFAULT_KEY_PAIR_DIRECTORY, f"{key_pair}.pem") password_file_name = os.path.join(FALLBACK_SERVER_LOC, f"access_key.txt") @@ -985,9 +971,7 @@ def deploy_to_routing_server( push_directory: str, ) -> bool: client = session.client("ec2") - server_host, allocation_id, association_id = get_instance_address( - session, instance_id - ) + server_host, allocation_id, association_id = get_instance_address(session, instance_id) keypair_file = os.path.join(DEFAULT_KEY_PAIR_DIRECTORY, f"{key_pair}.pem") print("Uploading files to server, then attempting to run") @@ -1066,9 +1050,7 @@ def remove_instance_and_cleanup( Cleanup for a launched server, removing the redirect rule clearing the target group, and then shutting down the instance. """ - server_detail_path = os.path.join( - DEFAULT_SERVER_DETAIL_LOCATION, f"{server_name}.json" - ) + server_detail_path = os.path.join(DEFAULT_SERVER_DETAIL_LOCATION, f"{server_name}.json") with open(server_detail_path, "r") as detail_file: details = json.load(detail_file) diff --git a/mephisto/abstractions/architects/ec2/prepare_ec2_servers.py b/mephisto/abstractions/architects/ec2/prepare_ec2_servers.py index 4d59dab2e..74dca5268 100644 --- a/mephisto/abstractions/architects/ec2/prepare_ec2_servers.py +++ b/mephisto/abstractions/architects/ec2/prepare_ec2_servers.py @@ -194,9 +194,7 @@ def launch_ec2_fallback( print(f"Using existing listener {listener_arn}") # Finally, deploy the fallback server contents: - ec2_helpers.deploy_fallback_server( - session, instance_id, key_pair_name, access_logs_key - ) + ec2_helpers.deploy_fallback_server(session, instance_id, key_pair_name, access_logs_key) existing_details["access_logs_key"] = access_logs_key update_details(saved_details_file, existing_details) @@ -210,9 +208,7 @@ def main(): domain_name = input("Please provide the domain name you will be using\n>> ") ssh_ip_block = input("Provide the CIDR IP block for ssh access\n>> ") - access_logs_key = input( - "Please provide a key password to use for accessing server logs\n>> " - ) + access_logs_key = input("Please provide a key password to use for accessing server logs\n>> ") launch_ec2_fallback(iam_role_name, domain_name, ssh_ip_block, access_logs_key) diff --git a/mephisto/abstractions/architects/heroku_architect.py b/mephisto/abstractions/architects/heroku_architect.py index a4397c87d..5b609c660 100644 --- a/mephisto/abstractions/architects/heroku_architect.py +++ b/mephisto/abstractions/architects/heroku_architect.py @@ -44,9 +44,7 @@ USER_NAME = getpass.getuser() HEROKU_SERVER_BUILD_DIRECTORY = "heroku_server" -HEROKU_CLIENT_URL = ( - "https://cli-assets.heroku.com/heroku-cli/channels/stable/heroku-cli" -) +HEROKU_CLIENT_URL = "https://cli-assets.heroku.com/heroku-cli/channels/stable/heroku-cli" HEROKU_WAIT_TIME = 3 @@ -59,9 +57,7 @@ class HerokuArchitectArgs(ArchitectArgs): """Additional arguments for configuring a heroku architect""" _architect_type: str = ARCHITECT_TYPE - use_hobby: bool = field( - default=False, metadata={"help": "Launch on the Heroku Hobby tier"} - ) + use_hobby: bool = field(default=False, metadata={"help": "Launch on the Heroku Hobby tier"}) heroku_team: Optional[str] = field( default=MISSING, metadata={"help": "Heroku team to use for this launch"} ) @@ -70,9 +66,7 @@ class HerokuArchitectArgs(ArchitectArgs): ) heroku_config_args: Dict[str, str] = field( default_factory=dict, - metadata={ - "help": "str:str dict containing all heroku config variables to set for the app" - }, + metadata={"help": "str:str dict containing all heroku config variables to set for the app"}, ) @@ -114,9 +108,7 @@ def __init__( self.heroku_config_args = dict(args.architect.heroku_config_args) # Cache-able parameters - self.__heroku_app_name: Optional[str] = args.architect.get( - "heroku_app_name", None - ) + self.__heroku_app_name: Optional[str] = args.architect.get("heroku_app_name", None) self.__heroku_app_url: Optional[str] = None self.__heroku_executable_path: Optional[str] = None self.__heroku_user_identifier: Optional[str] = None @@ -154,9 +146,7 @@ def download_file(self, target_filename: str, save_dir: str) -> None: Heroku architects need to download the file """ heroku_app_name = self.__get_app_name() - target_url = ( - f"https://{heroku_app_name}.herokuapp.com/download_file/{target_filename}" - ) + target_url = f"https://{heroku_app_name}.herokuapp.com/download_file/{target_filename}" dest_path = os.path.join(save_dir, target_filename) r = requests.get(target_url, stream=True) @@ -178,9 +168,7 @@ def assert_task_args(cls, args: DictConfig, shared_state: "SharedTaskState"): """ heroku_executable_path = HerokuArchitect.get_heroku_client_path() try: - output = subprocess.check_output( - shlex.split(heroku_executable_path + " auth:whoami") - ) + output = subprocess.check_output(shlex.split(heroku_executable_path + " auth:whoami")) except subprocess.CalledProcessError: raise Exception( "A free Heroku account is required for launching tasks via " @@ -218,9 +206,7 @@ def get_heroku_client_path() -> str: bit_architecture = "x86" # Find existing heroku files to use - existing_heroku_directory_names = glob.glob( - os.path.join(HEROKU_TMP_DIR, "heroku-cli-*") - ) + existing_heroku_directory_names = glob.glob(os.path.join(HEROKU_TMP_DIR, "heroku-cli-*")) if len(existing_heroku_directory_names) == 0: print("Getting heroku") if os.path.exists(os.path.join(HEROKU_TMP_DIR, "heroku.tar.gz")): @@ -281,10 +267,7 @@ def __get_heroku_client(self) -> Tuple[str, str]: """ Get an authorized heroku client path and authorization token """ - if ( - self.__heroku_executable_path is None - or self.__heroku_user_identifier is None - ): + if self.__heroku_executable_path is None or self.__heroku_user_identifier is None: ( heroku_executable_path, heroku_user_identifier, @@ -369,9 +352,7 @@ def __setup_heroku_server(self) -> str: ) else: subprocess.check_output( - shlex.split( - "{} create {}".format(heroku_executable_path, heroku_app_name) - ) + shlex.split("{} create {}".format(heroku_executable_path, heroku_app_name)) ) self.created = True except subprocess.CalledProcessError as e: # User has too many apps? @@ -394,9 +375,7 @@ def __setup_heroku_server(self) -> str: try: subprocess.check_output( shlex.split( - "{} features:enable http-session-affinity".format( - heroku_executable_path - ) + "{} features:enable http-session-affinity".format(heroku_executable_path) ) ) except subprocess.CalledProcessError: # Already enabled WebSockets @@ -419,14 +398,10 @@ def __setup_heroku_server(self) -> str: # commit and push to the heroku server sh.git(shlex.split(f"-C {heroku_server_directory_path} add -A")) sh.git(shlex.split(f'-C {heroku_server_directory_path} commit -m "app"')) - sh.git( - shlex.split(f"-C {heroku_server_directory_path} push -f heroku {branch}") - ) + sh.git(shlex.split(f"-C {heroku_server_directory_path} push -f heroku {branch}")) os.chdir(heroku_server_directory_path) - subprocess.check_output( - shlex.split("{} ps:scale web=1".format(heroku_executable_path)) - ) + subprocess.check_output(shlex.split("{} ps:scale web=1".format(heroku_executable_path))) if self.args.architect.use_hobby is True: try: @@ -437,8 +412,7 @@ def __setup_heroku_server(self) -> str: self.__delete_heroku_server() sh.rm(shlex.split("-rf {}".format(heroku_server_directory_path))) raise Exception( - "Server launched with hobby flag but account cannot create " - "hobby servers." + "Server launched with hobby flag but account cannot create " "hobby servers." ) os.chdir(return_dir) diff --git a/mephisto/abstractions/architects/local_architect.py b/mephisto/abstractions/architects/local_architect.py index d394fc8d5..37e3db233 100644 --- a/mephisto/abstractions/architects/local_architect.py +++ b/mephisto/abstractions/architects/local_architect.py @@ -171,9 +171,7 @@ def deploy(self) -> str: host = self.hostname port = self.port if host is None: - host = input( - "Please enter the public server address, like https://hostname.com: " - ) + host = input("Please enter the public server address, like https://hostname.com: ") self.hostname = host if port is None: port = input("Please enter the port given above, likely 3000: ") diff --git a/mephisto/abstractions/architects/mock_architect.py b/mephisto/abstractions/architects/mock_architect.py index 1e8f6bf99..55c7e8f68 100644 --- a/mephisto/abstractions/architects/mock_architect.py +++ b/mephisto/abstractions/architects/mock_architect.py @@ -332,9 +332,7 @@ def download_file(self, target_filename: str, save_dir: str) -> None: def prepare(self) -> str: """Mark the preparation call""" self.prepared = True - built_dir = os.path.join( - self.build_dir, "mock_build_{}".format(self.task_run_id) - ) + built_dir = os.path.join(self.build_dir, "mock_build_{}".format(self.task_run_id)) os.makedirs(built_dir) return built_dir diff --git a/mephisto/abstractions/architects/router/build_router.py b/mephisto/abstractions/architects/router/build_router.py index 2199685e0..4a63ebaa3 100644 --- a/mephisto/abstractions/architects/router/build_router.py +++ b/mephisto/abstractions/architects/router/build_router.py @@ -48,8 +48,7 @@ def install_router_files() -> None: packages_installed = subprocess.call(["npm", "install"]) if packages_installed != 0: raise Exception( - "please make sure node is installed, otherwise view " - "the above error for more info." + "please make sure node is installed, otherwise view " "the above error for more info." ) os.chdir(return_dir) @@ -93,9 +92,7 @@ def build_router( shutil.copytree(server_source_directory_path, local_server_directory_path) # Copy the required wrap crowd source path - local_crowd_source_path = os.path.join( - local_server_directory_path, CROWD_SOURCE_PATH - ) + local_crowd_source_path = os.path.join(local_server_directory_path, CROWD_SOURCE_PATH) crowd_provider = task_run.get_provider() shutil.copy2(crowd_provider.get_wrapper_js_path(), local_crowd_source_path) diff --git a/mephisto/abstractions/architects/router/flask/mephisto_flask_blueprint.py b/mephisto/abstractions/architects/router/flask/mephisto_flask_blueprint.py index d4e01db6c..d96ce4e14 100644 --- a/mephisto/abstractions/architects/router/flask/mephisto_flask_blueprint.py +++ b/mephisto/abstractions/architects/router/flask/mephisto_flask_blueprint.py @@ -230,9 +230,7 @@ def _handle_get_agent_status(self, agent_status_packet: Dict[str, Any]) -> None: "subject_id": SYSTEM_CHANNEL_ID, "data": agent_statuses, "client_timestamp": agent_status_packet["server_timestamp"], - "router_incoming_timestamp": agent_status_packet[ - "router_incoming_timestamp" - ], + "router_incoming_timestamp": agent_status_packet["router_incoming_timestamp"], } self._handle_forward(packet) @@ -354,9 +352,7 @@ def on_close(self, reason: Any) -> None: agent.is_alive = False agent.disconnect_time = time.time() - def make_agent_request( - self, request_packet: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: + def make_agent_request(self, request_packet: Dict[str, Any]) -> Optional[Dict[str, Any]]: """Make a request to the core Mephisto server, and then await the response""" request_id = request_packet["data"]["request_id"] diff --git a/mephisto/abstractions/blueprint.py b/mephisto/abstractions/blueprint.py index cc788c973..2e8efc982 100644 --- a/mephisto/abstractions/blueprint.py +++ b/mephisto/abstractions/blueprint.py @@ -45,9 +45,7 @@ class BlueprintArgs: _blueprint_type: str = MISSING block_qualification: str = field( default=MISSING, - metadata={ - "help": ("Specify the name of a qualification used to soft block workers.") - }, + metadata={"help": ("Specify the name of a qualification used to soft block workers.")}, ) tips_location: str = field( default=os.path.join(get_run_file_dir(), "assets/tips.csv"), @@ -68,9 +66,7 @@ class SharedTaskState: task_config: Dict[str, Any] = field( default_factory=dict, metadata={ - "help": ( - "Values to be included in the frontend MephistoTask.task_config object" - ), + "help": ("Values to be included in the frontend MephistoTask.task_config object"), "type": "Dict[str, Any]", "default": "{}", }, @@ -131,23 +127,17 @@ class BlueprintMixin(ABC): def extract_unique_mixins(blueprint_class: Type["Blueprint"]): """Return the unique mixin classes that are used in the given blueprint class""" mixin_subclasses = [ - clazz - for clazz in blueprint_class.mro() - if issubclass(clazz, BlueprintMixin) + clazz for clazz in blueprint_class.mro() if issubclass(clazz, BlueprintMixin) ] target_class: Union[Type["Blueprint"], Type["BlueprintMixin"]] = blueprint_class # Remove magic created with `mixin_args_and_state` while target_class.__name__ == "MixedInBlueprint": target_class = mixin_subclasses.pop(0) removed_locals = [ - clazz - for clazz in mixin_subclasses - if "MixedInBlueprint" not in clazz.__name__ + clazz for clazz in mixin_subclasses if "MixedInBlueprint" not in clazz.__name__ ] filtered_subclasses = set( - clazz - for clazz in removed_locals - if clazz != BlueprintMixin and clazz != target_class + clazz for clazz in removed_locals if clazz != BlueprintMixin and clazz != target_class ) # Remaining "Blueprints" should be dropped at this point. @@ -157,13 +147,9 @@ def extract_unique_mixins(blueprint_class: Type["Blueprint"]): # we also want to make sure that we don't double-count extensions of mixins, so remove classes that other classes are subclasses of def is_subclassed(clazz): - return True in [ - issubclass(x, clazz) and x != clazz for x in filtered_out_blueprints - ] + return True in [issubclass(x, clazz) and x != clazz for x in filtered_out_blueprints] - unique_subclasses = [ - clazz for clazz in filtered_out_blueprints if not is_subclassed(clazz) - ] + unique_subclasses = [clazz for clazz in filtered_out_blueprints if not is_subclassed(clazz)] return unique_subclasses @abstractmethod @@ -175,9 +161,7 @@ def init_mixin_config( @classmethod @abstractmethod - def assert_mixin_args( - cls, args: "DictConfig", shared_state: "SharedTaskState" - ) -> None: + def assert_mixin_args(cls, args: "DictConfig", shared_state: "SharedTaskState") -> None: """Method to validate the incoming args and throw if something won't work""" raise NotImplementedError() @@ -190,9 +174,7 @@ def get_mixin_qualifications( raise NotImplementedError() @classmethod - def mixin_args_and_state( - mixin_cls: Type["BlueprintMixin"], target_cls: Type["Blueprint"] - ): + def mixin_args_and_state(mixin_cls: Type["BlueprintMixin"], target_cls: Type["Blueprint"]): """ Magic utility decorator that can be used to inject mixin configurations (BlueprintArgs and SharedTaskState) without the user needing to define new @@ -240,9 +222,7 @@ class Blueprint(ABC): SharedStateClass: ClassVar[Type["SharedTaskState"]] = SharedTaskState BLUEPRINT_TYPE: str - def __init__( - self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState" - ): + def __init__(self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState"): self.args = args self.shared_state = shared_state self.frontend_task_config = shared_state.task_config @@ -252,9 +232,7 @@ def __init__( clazz.init_mixin_config(self, task_run, args, shared_state) @classmethod - def get_required_qualifications( - cls, args: DictConfig, shared_state: "SharedTaskState" - ): + def get_required_qualifications(cls, args: DictConfig, shared_state: "SharedTaskState"): quals = [] for clazz in BlueprintMixin.extract_unique_mixins(cls): quals += clazz.get_mixin_qualifications(args, shared_state) diff --git a/mephisto/abstractions/blueprints/abstract/static_task/static_blueprint.py b/mephisto/abstractions/blueprints/abstract/static_task/static_blueprint.py index 777b2dc29..b8402874f 100644 --- a/mephisto/abstractions/blueprints/abstract/static_task/static_blueprint.py +++ b/mephisto/abstractions/blueprints/abstract/static_task/static_blueprint.py @@ -177,19 +177,13 @@ def assert_task_args(cls, args: DictConfig, shared_state: "SharedTaskState"): blue_args = args.blueprint if blue_args.get("data_csv", None) is not None: csv_file = os.path.expanduser(blue_args.data_csv) - assert os.path.exists( - csv_file - ), f"Provided csv file {csv_file} doesn't exist" + assert os.path.exists(csv_file), f"Provided csv file {csv_file} doesn't exist" elif blue_args.get("data_json", None) is not None: json_file = os.path.expanduser(blue_args.data_json) - assert os.path.exists( - json_file - ), f"Provided JSON file {json_file} doesn't exist" + assert os.path.exists(json_file), f"Provided JSON file {json_file} doesn't exist" elif blue_args.get("data_jsonl", None) is not None: jsonl_file = os.path.expanduser(blue_args.data_jsonl) - assert os.path.exists( - jsonl_file - ), f"Provided JSON-L file {jsonl_file} doesn't exist" + assert os.path.exists(jsonl_file), f"Provided JSON-L file {jsonl_file} doesn't exist" elif shared_state.static_task_data is not None: if isinstance(shared_state.static_task_data, types.GeneratorType): # TODO(#97) can we check something about this? @@ -200,9 +194,7 @@ def assert_task_args(cls, args: DictConfig, shared_state: "SharedTaskState"): len([x for x in shared_state.static_task_data]) > 0 ), "Length of data dict provided was 0" else: - raise AssertionError( - "Must provide one of a data csv, json, json-L, or a list of tasks" - ) + raise AssertionError("Must provide one of a data csv, json, json-L, or a list of tasks") def get_initialization_data(self) -> Iterable["InitializationData"]: """ diff --git a/mephisto/abstractions/blueprints/abstract/static_task/static_task_runner.py b/mephisto/abstractions/blueprints/abstract/static_task/static_task_runner.py index 0aec59b5e..5ad76735c 100644 --- a/mephisto/abstractions/blueprints/abstract/static_task/static_task_runner.py +++ b/mephisto/abstractions/blueprints/abstract/static_task/static_task_runner.py @@ -29,9 +29,7 @@ class StaticTaskRunner(TaskRunner): as only one person can work on them at a time """ - def __init__( - self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState" - ): + def __init__(self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState"): super().__init__(task_run, args, shared_state) self.is_concurrent = False self.assignment_duration_in_seconds = ( diff --git a/mephisto/abstractions/blueprints/mixins/onboarding_required.py b/mephisto/abstractions/blueprints/mixins/onboarding_required.py index 7c9555d2b..c1a9a3892 100644 --- a/mephisto/abstractions/blueprints/mixins/onboarding_required.py +++ b/mephisto/abstractions/blueprints/mixins/onboarding_required.py @@ -84,9 +84,7 @@ def init_mixin_config( self.init_onboarding_config(task_run, args, shared_state) @classmethod - def assert_mixin_args( - cls, args: "DictConfig", shared_state: "SharedTaskState" - ) -> None: + def assert_mixin_args(cls, args: "DictConfig", shared_state: "SharedTaskState") -> None: """Method to validate the incoming args and throw if something won't work""" # Is there any validation that should be done on the onboarding qualification name? return @@ -140,12 +138,8 @@ def init_onboarding_config( db, onboarding_qualification_name, ) - self.onboarding_failed_name = self.get_failed_qual( - onboarding_qualification_name - ) - self.onboarding_failed_id = find_or_create_qualification( - db, self.onboarding_failed_name - ) + self.onboarding_failed_name = self.get_failed_qual(onboarding_qualification_name) + self.onboarding_failed_id = find_or_create_qualification(db, self.onboarding_failed_name) @classmethod def clear_onboarding(self, worker: "Worker", qualification_name: str): @@ -163,9 +157,7 @@ def get_onboarding_data(self, worker_id: str) -> Dict[str, Any]: """ return self.onboarding_data - def validate_onboarding( - self, worker: "Worker", onboarding_agent: "OnboardingAgent" - ) -> bool: + def validate_onboarding(self, worker: "Worker", onboarding_agent: "OnboardingAgent") -> bool: """ Check the incoming onboarding data and evaluate if the worker has passed the qualification or not. Return True if the worker diff --git a/mephisto/abstractions/blueprints/mixins/screen_task_required.py b/mephisto/abstractions/blueprints/mixins/screen_task_required.py index bdff65f3b..5d09e353d 100644 --- a/mephisto/abstractions/blueprints/mixins/screen_task_required.py +++ b/mephisto/abstractions/blueprints/mixins/screen_task_required.py @@ -202,9 +202,7 @@ def get_screening_unit_data(self) -> Optional[Dict[str, Any]]: return None # No screening units left... @classmethod - def create_validation_function( - cls, args: "DictConfig", screen_unit: Callable[["Unit"], bool] - ): + def create_validation_function(cls, args: "DictConfig", screen_unit: Callable[["Unit"], bool]): """ Takes in a validator function to determine if validation units are passable, and returns a `on_unit_submitted` function to be used @@ -219,9 +217,8 @@ def _wrapped_validate(unit): agent = unit.get_assigned_agent() if agent is None: return # Cannot validate a unit with no agent - if ( - args.blueprint.max_screening_units == 0 - and agent.get_worker().is_qualified(passed_qualification_name) + if args.blueprint.max_screening_units == 0 and agent.get_worker().is_qualified( + passed_qualification_name ): return # Do not run validation if screening with regular tasks and worker is already qualified validation_result = screen_unit(unit) @@ -233,9 +230,7 @@ def _wrapped_validate(unit): return _wrapped_validate @classmethod - def get_mixin_qualifications( - cls, args: "DictConfig", shared_state: "SharedTaskState" - ): + def get_mixin_qualifications(cls, args: "DictConfig", shared_state: "SharedTaskState"): """Creates the relevant task qualifications for this task""" passed_qualification_name = args.blueprint.passed_qualification_name failed_qualification_name = args.blueprint.block_qualification diff --git a/mephisto/abstractions/blueprints/mixins/use_gold_unit.py b/mephisto/abstractions/blueprints/mixins/use_gold_unit.py index d30d2d6b0..06cf4bef7 100644 --- a/mephisto/abstractions/blueprints/mixins/use_gold_unit.py +++ b/mephisto/abstractions/blueprints/mixins/use_gold_unit.py @@ -49,9 +49,7 @@ class UseGoldUnitArgs: gold_qualification_base: str = field( default=MISSING, - metadata={ - "help": ("Basename for a qualification that tracks gold completion rates") - }, + metadata={"help": ("Basename for a qualification that tracks gold completion rates")}, ) max_gold_units: int = field( default=MISSING, @@ -69,18 +67,12 @@ class UseGoldUnitArgs: ) min_golds: int = field( default=1, - metadata={ - "help": ( - "Minimum golds a worker needs to complete before getting real units." - ) - }, + metadata={"help": ("Minimum golds a worker needs to complete before getting real units.")}, ) max_incorrect_golds: int = field( default=0, metadata={ - "help": ( - "Maximum number of golds a worker can get incorrect before being disqualified" - ) + "help": ("Maximum number of golds a worker can get incorrect before being disqualified") }, ) @@ -99,10 +91,7 @@ def get_gold_factory(golds: List[Dict[str, Any]]) -> GoldFactory: assert num_golds != 0, "Must provide at least one gold to get_gold_factory" def get_gold_for_worker(worker: "Worker"): - if ( - worker.db_id not in worker_gold_maps - or len(worker_gold_maps[worker.db_id]) == 0 - ): + if worker.db_id not in worker_gold_maps or len(worker_gold_maps[worker.db_id]) == 0: # create a list of gold indices a worker hasn't done worker_gold_maps[worker.db_id] = [x for x in range(num_golds)] # select a random gold index from what remains @@ -151,9 +140,7 @@ def worker_qualifies( @dataclass class GoldUnitSharedState: - get_gold_for_worker: GoldFactory = field( - default_factory=lambda: get_gold_factory([{}]) - ) + get_gold_for_worker: GoldFactory = field(default_factory=lambda: get_gold_factory([{}])) worker_needs_gold: Callable[[int, int, int, int], bool] = field( default_factory=lambda: worker_needs_gold, ) @@ -247,18 +234,14 @@ def assert_mixin_args(cls, args: "DictConfig", shared_state: "SharedTaskState"): # given a worker @staticmethod - def get_current_qual_or_default( - worker: "Worker", qual_name: str, default_val: Any = 0 - ) -> Any: + def get_current_qual_or_default(worker: "Worker", qual_name: str, default_val: Any = 0) -> Any: """Return the qualification of this name for the worker, or the default value""" found_qual = worker.get_granted_qualification(qual_name) return default_val if found_qual is None else found_qual.value def get_completion_stats_for_worker(self, worker: "Worker") -> Tuple[int, int, int]: """Return the correct and incorrect gold counts, as well as the total count for a worker""" - completed_units = UseGoldUnit.get_current_qual_or_default( - worker, self.task_count_qual_name - ) + completed_units = UseGoldUnit.get_current_qual_or_default(worker, self.task_count_qual_name) correct_golds = UseGoldUnit.get_current_qual_or_default( worker, self.golds_correct_qual_name ) @@ -299,9 +282,7 @@ def update_qualified_status(self, worker: "Worker") -> bool: return True return False - def get_gold_unit_data_for_worker( - self, worker: "Worker" - ) -> Optional[Dict[str, Any]]: + def get_gold_unit_data_for_worker(self, worker: "Worker") -> Optional[Dict[str, Any]]: if self.gold_units_launched >= self.gold_unit_cap: return None try: @@ -313,9 +294,7 @@ def get_gold_unit_data_for_worker( return None @classmethod - def create_validation_function( - cls, args: "DictConfig", screen_unit: Callable[["Unit"], bool] - ): + def create_validation_function(cls, args: "DictConfig", screen_unit: Callable[["Unit"], bool]): """ Takes in a validator function to determine if validation units are passable, and returns a `on_unit_submitted` function to be used @@ -330,10 +309,7 @@ def create_validation_function( def _wrapped_validate(unit): agent = unit.get_assigned_agent() if unit.unit_index != GOLD_UNIT_INDEX: - if ( - agent is not None - and agent.get_status() == AgentState.STATUS_COMPLETED - ): + if agent is not None and agent.get_status() == AgentState.STATUS_COMPLETED: worker = agent.get_worker() completed_units = UseGoldUnit.get_current_qual_or_default( worker, task_count_qual_name @@ -370,9 +346,7 @@ def _wrapped_validate(unit): return _wrapped_validate @classmethod - def get_mixin_qualifications( - cls, args: "DictConfig", shared_state: "SharedTaskState" - ): + def get_mixin_qualifications(cls, args: "DictConfig", shared_state: "SharedTaskState"): """Creates the relevant task qualifications for this task""" base_qual_name = args.blueprint.gold_qualification_base golds_failed_qual_name = f"{base_qual_name}-wrong-golds" diff --git a/mephisto/abstractions/blueprints/mock/mock_blueprint.py b/mephisto/abstractions/blueprints/mock/mock_blueprint.py index c3b205d61..5bfbc9d84 100644 --- a/mephisto/abstractions/blueprints/mock/mock_blueprint.py +++ b/mephisto/abstractions/blueprints/mock/mock_blueprint.py @@ -89,9 +89,7 @@ class MockBlueprint(Blueprint, OnboardingRequired, ScreenTaskRequired): ArgsMixin: ClassVar[Any] SharedStateMixin: ClassVar[Any] - def __init__( - self, task_run: "TaskRun", args: "DictConfig", shared_state: "MockSharedState" - ): + def __init__(self, task_run: "TaskRun", args: "DictConfig", shared_state: "MockSharedState"): super().__init__(task_run, args, shared_state) def get_initialization_data(self) -> Iterable[InitializationData]: @@ -103,9 +101,7 @@ def get_initialization_data(self) -> Iterable[InitializationData]: for i in range(self.args.blueprint.num_assignments) ] - def validate_onboarding( - self, worker: "Worker", onboarding_agent: "OnboardingAgent" - ) -> bool: + def validate_onboarding(self, worker: "Worker", onboarding_agent: "OnboardingAgent") -> bool: """ Onboarding validation for MockBlueprints just returns the 'should_pass' field """ diff --git a/mephisto/abstractions/blueprints/mock/mock_task_runner.py b/mephisto/abstractions/blueprints/mock/mock_task_runner.py index f590ac7db..46f6098dd 100644 --- a/mephisto/abstractions/blueprints/mock/mock_task_runner.py +++ b/mephisto/abstractions/blueprints/mock/mock_task_runner.py @@ -24,9 +24,7 @@ class MockTaskRunner(TaskRunner): """Mock of a task runner, for use in testing""" - def __init__( - self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState" - ): + def __init__(self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState"): super().__init__(task_run, args, shared_state) self.timeout = args.blueprint.timeout_time self.tracked_tasks: Dict[str, Union["Assignment", "Unit"]] = {} @@ -66,9 +64,7 @@ def run_unit(self, unit: "Unit", agent: "Agent"): time.sleep(0.3) assigned_agent = unit.get_assigned_agent() assert assigned_agent is not None, "No agent was assigned" - assert ( - assigned_agent.db_id == agent.db_id - ), "Task was not given to assigned agent" + assert assigned_agent.db_id == agent.db_id, "Task was not given to assigned agent" packet = agent.get_live_update(timeout=self.timeout) if packet is not None: agent.observe(packet) diff --git a/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_blueprint.py b/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_blueprint.py index 70a56f091..e05bfd153 100644 --- a/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_blueprint.py +++ b/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_blueprint.py @@ -122,9 +122,7 @@ class ParlAIChatBlueprintArgs(OnboardingRequiredArgs, BlueprintArgs): ) num_conversations: int = field( default=MISSING, - metadata={ - "help": "Optional count of conversations to have if no context provided" - }, + metadata={"help": "Optional count of conversations to have if no context provided"}, ) @@ -184,9 +182,7 @@ def __init__( self.world_module = world_module assert hasattr(world_module, "make_world") assert hasattr(world_module, "get_world_params") - self.agent_count = world_module.get_world_params()[ # type: ignore - "agent_count" - ] + self.agent_count = world_module.get_world_params()["agent_count"] # type: ignore self.full_task_description = MISSING_SOMETHING_TEXT if args.blueprint.get("task_description_file", None) is not None: @@ -206,9 +202,7 @@ def __init__( self.full_preview_description = description_fp.read() @classmethod - def assert_task_args( - cls, args: "DictConfig", shared_state: "SharedTaskState" - ) -> None: + def assert_task_args(cls, args: "DictConfig", shared_state: "SharedTaskState") -> None: """Ensure that arguments are properly configured to launch this task""" # Find world module assert isinstance( @@ -225,9 +219,7 @@ def assert_task_args( world_module_name = os.path.basename(world_file_path)[:-3] world_module = import_module(world_module_name) # assert world file is valid - assert hasattr( - world_module, "make_world" - ), "Provided world file has no `make_world` method" + assert hasattr(world_module, "make_world"), "Provided world file has no `make_world` method" assert hasattr( world_module, "get_world_params" ), "Provided world file has no `get_world_params` method" @@ -235,35 +227,27 @@ def assert_task_args( # assert some method for determining quantity of conversations if args.blueprint.get("context_csv", None) is not None: csv_file = os.path.expanduser(args.blueprint.context_csv) - assert os.path.exists( - csv_file - ), f"Target context_csv path {csv_file} doesn't exist" + assert os.path.exists(csv_file), f"Target context_csv path {csv_file} doesn't exist" elif args.blueprint.get("context_jsonl", None) is not None: jsonl_file = os.path.expanduser(args.blueprint.context_jsonl) assert os.path.exists( jsonl_file ), f"Target context_jsonl path {jsonl_file} doesn't exist" elif args.blueprint.get("num_conversations", None) is not None: - assert ( - args.blueprint.num_conversations > 0 - ), "Must have at least one conversation" + assert args.blueprint.num_conversations > 0, "Must have at least one conversation" else: raise AssertionError( "Must specify one of --context-csv, --context-jsonl or --num-conversations" ) if args.blueprint.get("custom_source_bundle", None) is not None: - custom_source_file_path = os.path.expanduser( - args.blueprint.custom_source_bundle - ) + custom_source_file_path = os.path.expanduser(args.blueprint.custom_source_bundle) assert os.path.exists( custom_source_file_path ), f"Provided custom bundle doesn't exist at {custom_source_file_path}" if args.blueprint.get("custom_source_dir", None) is not None: - custom_source_dir_path = os.path.expanduser( - args.blueprint.custom_source_dir - ) + custom_source_dir_path = os.path.expanduser(args.blueprint.custom_source_dir) assert os.path.exists( custom_source_dir_path ), f"Provided custom source dir doesn't exist at {custom_source_dir_path}" @@ -298,8 +282,7 @@ def get_frontend_args(self) -> Dict[str, Any]: "preview_html": self.full_preview_description, "frame_height": 650, "chat_title": self.args.task.task_title, - "has_preview": self.args.blueprint.get("preview_source", None) - is not None, + "has_preview": self.args.blueprint.get("preview_source", None) is not None, "block_mobile": True, "frontend_task_opts": shared_state.frontend_task_opts, } @@ -317,9 +300,7 @@ def get_initialization_data(self) -> Iterable["InitializationData"]: for d in self._initialization_data_dicts ] - def validate_onboarding( - self, worker: "Worker", onboarding_agent: "OnboardingAgent" - ) -> bool: + def validate_onboarding(self, worker: "Worker", onboarding_agent: "OnboardingAgent") -> bool: if hasattr(self.world_module, "validate_onboarding"): return self.world_module.validate_onboarding( # type: ignore onboarding_agent.state.get_data() diff --git a/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_task_builder.py b/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_task_builder.py index cc83680ae..f56966770 100644 --- a/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_task_builder.py +++ b/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_task_builder.py @@ -199,9 +199,7 @@ def build_in_dir(self, build_dir: str): # Copy over the static files for this task: for fin_file in ["index.html", "notif.mp3"]: - copied_static_file = os.path.join( - FRONTEND_SOURCE_DIR, "src", "static", fin_file - ) + copied_static_file = os.path.join(FRONTEND_SOURCE_DIR, "src", "static", fin_file) target_path = os.path.join(target_resource_dir, fin_file) shutil.copy2(copied_static_file, target_path) diff --git a/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_task_runner.py b/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_task_runner.py index 71f178676..f2414d7fd 100644 --- a/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_task_runner.py +++ b/mephisto/abstractions/blueprints/parlai_chat/parlai_chat_task_runner.py @@ -84,9 +84,7 @@ def act(self, timeout=None): if gotten_act is None: # No act received, see that one is requested: if not self.__act_requested: - self.mephisto_agent.observe( - {"task_data": {"live_update_requested": True}} - ) + self.mephisto_agent.observe({"task_data": {"live_update_requested": True}}) self.__act_requested = True if timeout is not None: gotten_act = self.mephisto_agent.get_live_update(timeout=timeout) @@ -108,9 +106,7 @@ class ParlAIChatTaskRunner(TaskRunner): Task runner for a parlai chat task """ - def __init__( - self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState" - ): + def __init__(self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState"): super().__init__(task_run, args, shared_state) from mephisto.abstractions.blueprints.parlai_chat.parlai_chat_blueprint import ( SharedParlAITaskState, @@ -179,10 +175,7 @@ def run_onboarding(self, agent: "OnboardingAgent") -> None: world_id = self.get_world_id("onboard", agent.get_agent_id()) self.id_to_worlds[world_id] = world - while ( - not world.episode_done() - and agent.get_agent_id() in self.running_onboardings - ): + while not world.episode_done() and agent.get_agent_id() in self.running_onboardings: world.parley() # Ensure agent can submit after onboarding diff --git a/mephisto/abstractions/blueprints/remote_procedure/remote_procedure_blueprint.py b/mephisto/abstractions/blueprints/remote_procedure/remote_procedure_blueprint.py index 32f64bb05..b035bcdf8 100644 --- a/mephisto/abstractions/blueprints/remote_procedure/remote_procedure_blueprint.py +++ b/mephisto/abstractions/blueprints/remote_procedure/remote_procedure_blueprint.py @@ -119,9 +119,7 @@ class RemoteProcedureBlueprintArgs( @register_mephisto_abstraction() -class RemoteProcedureBlueprint( - ScreenTaskRequired, OnboardingRequired, UseGoldUnit, Blueprint -): +class RemoteProcedureBlueprint(ScreenTaskRequired, OnboardingRequired, UseGoldUnit, Blueprint): """Blueprint for a task that runs a parlai chat""" AgentStateClass: ClassVar[Type["AgentState"]] = RemoteProcedureAgentState @@ -172,9 +170,7 @@ def __init__( pass @classmethod - def assert_task_args( - cls, args: "DictConfig", shared_state: "SharedTaskState" - ) -> None: + def assert_task_args(cls, args: "DictConfig", shared_state: "SharedTaskState") -> None: """Ensure that arguments are properly configured to launch this task""" assert isinstance( shared_state, SharedRemoteProcedureTaskState @@ -182,19 +178,13 @@ def assert_task_args( blue_args = args.blueprint if blue_args.get("data_csv", None) is not None: csv_file = os.path.expanduser(blue_args.data_csv) - assert os.path.exists( - csv_file - ), f"Provided csv file {csv_file} doesn't exist" + assert os.path.exists(csv_file), f"Provided csv file {csv_file} doesn't exist" elif blue_args.get("data_json", None) is not None: json_file = os.path.expanduser(blue_args.data_json) - assert os.path.exists( - json_file - ), f"Provided JSON file {json_file} doesn't exist" + assert os.path.exists(json_file), f"Provided JSON file {json_file} doesn't exist" elif blue_args.get("data_jsonl", None) is not None: jsonl_file = os.path.expanduser(blue_args.data_jsonl) - assert os.path.exists( - jsonl_file - ), f"Provided JSON-L file {jsonl_file} doesn't exist" + assert os.path.exists(jsonl_file), f"Provided JSON-L file {jsonl_file} doesn't exist" elif shared_state.static_task_data is not None: if isinstance(shared_state.static_task_data, types.GeneratorType): # TODO can we check something about this? @@ -204,9 +194,7 @@ def assert_task_args( len([x for x in shared_state.static_task_data]) > 0 ), "Length of data dict provided was 0" else: - raise AssertionError( - "Must provide one of a data csv, json, json-L, or a list of tasks" - ) + raise AssertionError("Must provide one of a data csv, json, json-L, or a list of tasks") assert shared_state.function_registry is not None, ( "Must provide a valid function registry to use with the task, a mapping " "of function names to functions that take as input a string and an agent " diff --git a/mephisto/abstractions/blueprints/remote_procedure/remote_procedure_task_runner.py b/mephisto/abstractions/blueprints/remote_procedure/remote_procedure_task_runner.py index d83c5867b..8c0ecac31 100644 --- a/mephisto/abstractions/blueprints/remote_procedure/remote_procedure_task_runner.py +++ b/mephisto/abstractions/blueprints/remote_procedure/remote_procedure_task_runner.py @@ -68,9 +68,7 @@ def get_init_data_for_agent(self, agent: "Agent") -> Dict[str, Any]: assert new_state is not None, "Recently initialized state still None" return new_state - def _agent_in_onboarding_or_live( - self, agent: Union["Agent", "OnboardingAgent"] - ) -> bool: + def _agent_in_onboarding_or_live(self, agent: Union["Agent", "OnboardingAgent"]) -> bool: """Determine if an agent server should still be maintained""" return ( agent.get_agent_id() in self.running_units @@ -123,9 +121,7 @@ def run_onboarding(self, agent: "OnboardingAgent") -> None: ): self._run_server_timestep_for_agent(agent) - remaining_time = self.assignment_duration_in_seconds - ( - time.time() - start_time - ) + remaining_time = self.assignment_duration_in_seconds - (time.time() - start_time) agent.await_submit(timeout=remaining_time) def cleanup_onboarding(self, agent: "OnboardingAgent") -> None: @@ -144,9 +140,7 @@ def run_unit(self, unit: "Unit", agent: "Agent") -> None: ): self._run_server_timestep_for_agent(agent) - remaining_time = self.assignment_duration_in_seconds - ( - time.time() - start_time - ) + remaining_time = self.assignment_duration_in_seconds - (time.time() - start_time) agent.await_submit(timeout=remaining_time) def cleanup_unit(self, unit: "Unit") -> None: diff --git a/mephisto/abstractions/blueprints/static_html_task/static_html_blueprint.py b/mephisto/abstractions/blueprints/static_html_task/static_html_blueprint.py index 75d668354..8facfca04 100644 --- a/mephisto/abstractions/blueprints/static_html_task/static_html_blueprint.py +++ b/mephisto/abstractions/blueprints/static_html_task/static_html_blueprint.py @@ -128,34 +128,24 @@ def assert_task_args(cls, args: DictConfig, shared_state: "SharedTaskState"): raise AssertionError("You can't launch an HTML static task on a generator") if blue_args.get("data_csv", None) is not None: csv_file = os.path.expanduser(blue_args.data_csv) - assert os.path.exists( - csv_file - ), f"Provided csv file {csv_file} doesn't exist" + assert os.path.exists(csv_file), f"Provided csv file {csv_file} doesn't exist" elif blue_args.get("data_json", None) is not None: json_file = os.path.expanduser(blue_args.data_json) - assert os.path.exists( - json_file - ), f"Provided JSON file {json_file} doesn't exist" + assert os.path.exists(json_file), f"Provided JSON file {json_file} doesn't exist" elif blue_args.get("data_jsonl", None) is not None: jsonl_file = os.path.expanduser(blue_args.data_jsonl) - assert os.path.exists( - jsonl_file - ), f"Provided JSON-L file {jsonl_file} doesn't exist" + assert os.path.exists(jsonl_file), f"Provided JSON-L file {jsonl_file} doesn't exist" elif shared_state.static_task_data is not None: assert ( len([w for w in shared_state.static_task_data]) > 0 ), "Length of data dict provided was 0" else: - raise AssertionError( - "Must provide one of a data csv, json, json-L, or a list of tasks" - ) + raise AssertionError("Must provide one of a data csv, json, json-L, or a list of tasks") if blue_args.get("onboarding_qualification", None) is not None: assert blue_args.get("onboarding_source", None) is not None, ( - "Must use onboarding html with an onboarding qualification to " - "use onboarding." + "Must use onboarding html with an onboarding qualification to " "use onboarding." ) assert shared_state.validate_onboarding is not None, ( - "Must use an onboarding validation function to use onboarding " - "with static tasks." + "Must use an onboarding validation function to use onboarding " "with static tasks." ) diff --git a/mephisto/abstractions/blueprints/static_react_task/static_react_blueprint.py b/mephisto/abstractions/blueprints/static_react_task/static_react_blueprint.py index 4f3373c02..68bae6fac 100644 --- a/mephisto/abstractions/blueprints/static_react_task/static_react_blueprint.py +++ b/mephisto/abstractions/blueprints/static_react_task/static_react_blueprint.py @@ -85,9 +85,7 @@ class StaticReactBlueprint(StaticBlueprint): ArgsClass = StaticReactBlueprintArgs BLUEPRINT_TYPE = BLUEPRINT_TYPE_STATIC_REACT - def __init__( - self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState" - ): + def __init__(self, task_run: "TaskRun", args: "DictConfig", shared_state: "SharedTaskState"): assert isinstance( shared_state, SharedStaticTaskState ), "Cannot initialize with a non-static state" @@ -99,9 +97,7 @@ def __init__( ) @classmethod - def assert_task_args( - cls, args: "DictConfig", shared_state: "SharedTaskState" - ) -> None: + def assert_task_args(cls, args: "DictConfig", shared_state: "SharedTaskState") -> None: """Ensure that static requirements are fulfilled, and source file exists""" assert isinstance( shared_state, SharedStaticTaskState diff --git a/mephisto/abstractions/crowd_provider.py b/mephisto/abstractions/crowd_provider.py index a2707f5c5..49325de9f 100644 --- a/mephisto/abstractions/crowd_provider.py +++ b/mephisto/abstractions/crowd_provider.py @@ -126,9 +126,7 @@ def setup_resources_for_task_run( raise NotImplementedError() @abstractmethod - def cleanup_resources_from_task_run( - self, task_run: "TaskRun", server_url: str - ) -> None: + def cleanup_resources_from_task_run(self, task_run: "TaskRun", server_url: str) -> None: """ Destroy any resources set up specifically for this task run """ diff --git a/mephisto/abstractions/database.py b/mephisto/abstractions/database.py index 5b4812913..94ee46ac8 100644 --- a/mephisto/abstractions/database.py +++ b/mephisto/abstractions/database.py @@ -45,9 +45,7 @@ class EntryDoesNotExistException(MephistoDBException): # Initialize histogram for database latency -DATABASE_LATENCY = Histogram( - "database_latency_seconds", "Logging for db requests", ["method"] -) +DATABASE_LATENCY = Histogram("database_latency_seconds", "Logging for db requests", ["method"]) # Need all the specific decorators b/c cascading is not allowed in decorators # thanks to https://mail.python.org/pipermail/python-dev/2004-August/046711.html NEW_PROJECT_LATENCY = DATABASE_LATENCY.labels(method="new_project") @@ -78,17 +76,11 @@ class EntryDoesNotExistException(MephistoDBException): GET_AGENT_LATENCY = DATABASE_LATENCY.labels(method="get_agent") FIND_AGENTS_LATENCY = DATABASE_LATENCY.labels(method="find_agents") UPDATE_AGENT_LATENCY = DATABASE_LATENCY.labels(method="update_agent") -CLEAR_UNIT_AGENT_ASSIGNMENT_LATENCY = DATABASE_LATENCY.labels( - method="clear_unit_agent_assignment" -) +CLEAR_UNIT_AGENT_ASSIGNMENT_LATENCY = DATABASE_LATENCY.labels(method="clear_unit_agent_assignment") NEW_ONBOARDING_AGENT_LATENCY = DATABASE_LATENCY.labels(method="new_onboarding_agent") GET_ONBOARDING_AGENT_LATENCY = DATABASE_LATENCY.labels(method="get_onboarding_agent") -FIND_ONBOARDING_AGENTS_LATENCY = DATABASE_LATENCY.labels( - method="find_onboarding_agents" -) -UPDATE_ONBOARDING_AGENT_LATENCY = DATABASE_LATENCY.labels( - method="update_onboarding_agent" -) +FIND_ONBOARDING_AGENTS_LATENCY = DATABASE_LATENCY.labels(method="find_onboarding_agents") +UPDATE_ONBOARDING_AGENT_LATENCY = DATABASE_LATENCY.labels(method="update_onboarding_agent") MAKE_QUALIFICATION_LATENCY = DATABASE_LATENCY.labels(method="make_qualification") GET_QUALIFICATION_LATENCY = DATABASE_LATENCY.labels(method="get_qualification") FIND_QUALIFICATIONS_LATENCY = DATABASE_LATENCY.labels(method="find_qualifications") @@ -98,9 +90,7 @@ class EntryDoesNotExistException(MephistoDBException): CHECK_GRANTED_QUALIFICATIONS_LATENCY = DATABASE_LATENCY.labels( method="check_granted_qualifications" ) -GET_GRANTED_QUALIFICATION_LATENCY = DATABASE_LATENCY.labels( - method="get_granted_qualification" -) +GET_GRANTED_QUALIFICATION_LATENCY = DATABASE_LATENCY.labels(method="get_granted_qualification") REVOKE_QUALIFICATION_LATENCY = DATABASE_LATENCY.labels(method="revoke_qualification") @@ -238,9 +228,7 @@ def new_task( Create a new task with the given task name. Raise EntryAlreadyExistsException if a task with this name has already been created. """ - return self._new_task( - task_name=task_name, task_type=task_type, project_id=project_id - ) + return self._new_task(task_name=task_name, task_type=task_type, project_id=project_id) @abstractmethod def _get_task(self, task_id: str) -> Mapping[str, Any]: @@ -633,9 +621,7 @@ def new_requester(self, requester_name: str, provider_type: str) -> str: Raises EntryAlreadyExistsException if there is already a requester with this name """ - return self._new_requester( - requester_name=requester_name, provider_type=provider_type - ) + return self._new_requester(requester_name=requester_name, provider_type=provider_type) @abstractmethod def _get_requester(self, requester_id: str) -> Mapping[str, Any]: @@ -667,9 +653,7 @@ def find_requesters( Try to find any requester that matches the above. When called with no arguments, return all requesters. """ - return self._find_requesters( - requester_name=requester_name, provider_type=provider_type - ) + return self._find_requesters(requester_name=requester_name, provider_type=provider_type) @abstractmethod def _new_worker(self, worker_name: str, provider_type: str) -> str: @@ -886,9 +870,7 @@ def update_onboarding_agent( Update the given onboarding agent with the given parameters if possible, raise appropriate exception otherwise. """ - return self._update_onboarding_agent( - onboarding_agent_id=onboarding_agent_id, status=status - ) + return self._update_onboarding_agent(onboarding_agent_id=onboarding_agent_id, status=status) @abstractmethod def _find_onboarding_agents( @@ -937,16 +919,12 @@ def make_qualification(self, qualification_name: str) -> str: return self._make_qualification(qualification_name=qualification_name) @abstractmethod - def _find_qualifications( - self, qualification_name: Optional[str] = None - ) -> List[Qualification]: + def _find_qualifications(self, qualification_name: Optional[str] = None) -> List[Qualification]: """find_qualifications implementation""" raise NotImplementedError() @FIND_QUALIFICATIONS_LATENCY.time() - def find_qualifications( - self, qualification_name: Optional[str] = None - ) -> List[Qualification]: + def find_qualifications(self, qualification_name: Optional[str] = None) -> List[Qualification]: """ Find a qualification. If no name is supplied, returns all qualifications. """ @@ -987,7 +965,8 @@ def delete_qualification(self, qualification_name: str) -> None: @FIND_GRANT_QUALIFICATION_LATENCY.time() def find_granted_qualifications( - self, worker_id: Optional[str] = None, + self, + worker_id: Optional[str] = None, ) -> List[GrantedQualification]: """ Find granted qualifications. @@ -996,16 +975,12 @@ def find_granted_qualifications( return self._check_granted_qualifications(worker_id=worker_id) @abstractmethod - def _grant_qualification( - self, qualification_id: str, worker_id: str, value: int = 1 - ) -> None: + def _grant_qualification(self, qualification_id: str, worker_id: str, value: int = 1) -> None: """grant_qualification implementation""" raise NotImplementedError() @GRANT_QUALIFICATION_LATENCY.time() - def grant_qualification( - self, qualification_id: str, worker_id: str, value: int = 1 - ) -> None: + def grant_qualification(self, qualification_id: str, worker_id: str, value: int = 1) -> None: """ Grant a worker the given qualification. Update the qualification value if it already exists @@ -1047,9 +1022,7 @@ def _get_granted_qualification( raise NotImplementedError() @GET_GRANTED_QUALIFICATION_LATENCY.time() - def get_granted_qualification( - self, qualification_id: str, worker_id: str - ) -> Mapping[str, Any]: + def get_granted_qualification(self, qualification_id: str, worker_id: str) -> Mapping[str, Any]: """ Return the granted qualification in the database between the given worker and qualification id @@ -1070,9 +1043,7 @@ def revoke_qualification(self, qualification_id: str, worker_id: str) -> None: """ Remove the given qualification from the given worker """ - return self._revoke_qualification( - qualification_id=qualification_id, worker_id=worker_id - ) + return self._revoke_qualification(qualification_id=qualification_id, worker_id=worker_id) # File/blob manipulation methods diff --git a/mephisto/abstractions/databases/local_database.py b/mephisto/abstractions/databases/local_database.py index 9d8b1168f..8f759a48e 100644 --- a/mephisto/abstractions/databases/local_database.py +++ b/mephisto/abstractions/databases/local_database.py @@ -290,9 +290,7 @@ def init_tables(self) -> None: c.execute(CREATE_ONBOARDING_AGENTS_TABLE) c.executescript(CREATE_CORE_INDEXES) - def __get_one_by_id( - self, table_name: str, id_name: str, db_id: str - ) -> Mapping[str, Any]: + def __get_one_by_id(self, table_name: str, id_name: str, db_id: str) -> Mapping[str, Any]: """ Try to request the row for the given table and entry, raise EntryDoesNotExistException if it isn't present @@ -309,9 +307,7 @@ def __get_one_by_id( ) results = c.fetchall() if len(results) != 1: - raise EntryDoesNotExistException( - f"Table {table_name} has no {id_name} {db_id}" - ) + raise EntryDoesNotExistException(f"Table {table_name} has no {id_name} {db_id}") return results[0] def __create_query_and_tuple( @@ -335,9 +331,7 @@ def __create_query_and_tuple( return "", () query_lines = [ - f"WHERE {arg_name} = ?{idx+1}\n" - if idx == 0 - else f"AND {arg_name} = ?{idx+1}\n" + f"WHERE {arg_name} = ?{idx+1}\n" if idx == 0 else f"AND {arg_name} = ?{idx+1}\n" for idx, arg_name in enumerate(fin_args) ] @@ -353,18 +347,14 @@ def _new_project(self, project_name: str) -> str: with self.table_access_condition, self._get_connection() as conn: c = conn.cursor() try: - c.execute( - "INSERT INTO projects(project_name) VALUES (?);", (project_name,) - ) + c.execute("INSERT INTO projects(project_name) VALUES (?);", (project_name,)) project_id = str(c.lastrowid) return project_id except sqlite3.IntegrityError as e: if is_key_failure(e): raise EntryDoesNotExistException() elif is_unique_failure(e): - raise EntryAlreadyExistsException( - f"Project {project_name} already exists" - ) + raise EntryAlreadyExistsException(f"Project {project_name} already exists") raise MephistoDBException(e) def _get_project(self, project_id: str) -> Mapping[str, Any]: @@ -395,10 +385,7 @@ def _find_projects(self, project_name: Optional[str] = None) -> List[Project]: arg_tuple, ) rows = c.fetchall() - return [ - Project(self, str(r["project_id"]), row=r, _used_new_call=True) - for r in rows - ] + return [Project(self, str(r["project_id"]), row=r, _used_new_call=True) for r in rows] def _new_task( self, @@ -471,9 +458,7 @@ def _find_tasks( arg_tuple, ) rows = c.fetchall() - return [ - Task(self, str(r["task_id"]), row=r, _used_new_call=True) for r in rows - ] + return [Task(self, str(r["task_id"]), row=r, _used_new_call=True) for r in rows] def _update_task( self, @@ -518,9 +503,7 @@ def _update_task( if is_key_failure(e): raise EntryDoesNotExistException(e) elif is_unique_failure(e): - raise EntryAlreadyExistsException( - f"Task name {task_name} is already in use" - ) + raise EntryAlreadyExistsException(f"Task name {task_name} is already in use") raise MephistoDBException(e) def _new_task_run( @@ -600,10 +583,7 @@ def _find_task_runs( arg_tuple, ) rows = c.fetchall() - return [ - TaskRun(self, str(r["task_run_id"]), row=r, _used_new_call=True) - for r in rows - ] + return [TaskRun(self, str(r["task_run_id"]), row=r, _used_new_call=True) for r in rows] def _update_task_run(self, task_run_id: str, is_completed: bool): """ @@ -713,8 +693,7 @@ def _find_assignments( ) rows = c.fetchall() return [ - Assignment(self, str(r["assignment_id"]), row=r, _used_new_call=True) - for r in rows + Assignment(self, str(r["assignment_id"]), row=r, _used_new_call=True) for r in rows ] def _new_unit( @@ -837,9 +816,7 @@ def _find_units( arg_tuple, ) rows = c.fetchall() - return [ - Unit(self, str(r["unit_id"]), row=r, _used_new_call=True) for r in rows - ] + return [Unit(self, str(r["unit_id"]), row=r, _used_new_call=True) for r in rows] def _clear_unit_agent_assignment(self, unit_id: str) -> None: """ @@ -954,8 +931,7 @@ def _find_requesters( ) rows = c.fetchall() return [ - Requester(self, str(r["requester_id"]), row=r, _used_new_call=True) - for r in rows + Requester(self, str(r["requester_id"]), row=r, _used_new_call=True) for r in rows ] def _new_worker(self, worker_name: str, provider_type: str) -> str: @@ -1014,10 +990,7 @@ def _find_workers( arg_tuple, ) rows = c.fetchall() - return [ - Worker(self, str(r["worker_id"]), row=r, _used_new_call=True) - for r in rows - ] + return [Worker(self, str(r["worker_id"]), row=r, _used_new_call=True) for r in rows] def _new_agent( self, @@ -1155,10 +1128,7 @@ def _find_agents( arg_tuple, ) rows = c.fetchall() - return [ - Agent(self, str(r["agent_id"]), row=r, _used_new_call=True) - for r in rows - ] + return [Agent(self, str(r["agent_id"]), row=r, _used_new_call=True) for r in rows] def _make_qualification(self, qualification_name: str) -> str: """ @@ -1181,9 +1151,7 @@ def _make_qualification(self, qualification_name: str) -> str: raise EntryAlreadyExistsException() raise MephistoDBException(e) - def _find_qualifications( - self, qualification_name: Optional[str] = None - ) -> List[Qualification]: + def _find_qualifications(self, qualification_name: Optional[str] = None) -> List[Qualification]: """ Find a qualification. If no name is supplied, returns all qualifications. """ @@ -1202,9 +1170,7 @@ def _find_qualifications( ) rows = c.fetchall() return [ - Qualification( - self, str(r["qualification_id"]), row=r, _used_new_call=True - ) + Qualification(self, str(r["qualification_id"]), row=r, _used_new_call=True) for r in rows ] @@ -1215,9 +1181,7 @@ def _get_qualification(self, qualification_id: str) -> Mapping[str, Any]: See Qualification for the expected fields for the returned mapping """ - return self.__get_one_by_id( - "qualifications", "qualification_id", qualification_id - ) + return self.__get_one_by_id("qualifications", "qualification_id", qualification_id) def _delete_qualification(self, qualification_name: str) -> None: """ @@ -1225,9 +1189,7 @@ def _delete_qualification(self, qualification_name: str) -> None: """ qualifications = self.find_qualifications(qualification_name=qualification_name) if len(qualifications) == 0: - raise EntryDoesNotExistException( - f"No qualification found by name {qualification_name}" - ) + raise EntryDoesNotExistException(f"No qualification found by name {qualification_name}") qualification = qualifications[0] with self.table_access_condition, self._get_connection() as conn: c = conn.cursor() @@ -1240,9 +1202,7 @@ def _delete_qualification(self, qualification_name: str) -> None: (qualification_name,), ) - def _grant_qualification( - self, qualification_id: str, worker_id: str, value: int = 1 - ) -> None: + def _grant_qualification(self, qualification_id: str, worker_id: str, value: int = 1) -> None: """ Grant a worker the given qualification. Update the qualification value if it already exists @@ -1311,7 +1271,10 @@ def _check_granted_qualifications( rows = c.fetchall() return [ GrantedQualification( - self, str(r["qualification_id"]), str(r["worker_id"]), row=r, + self, + str(r["qualification_id"]), + str(r["worker_id"]), + row=r, ) for r in rows ] @@ -1396,9 +1359,7 @@ def _get_onboarding_agent(self, onboarding_agent_id: str) -> Mapping[str, Any]: Returns a SQLite Row object with the expected fields """ - return self.__get_one_by_id( - "onboarding_agents", "onboarding_agent_id", onboarding_agent_id - ) + return self.__get_one_by_id("onboarding_agents", "onboarding_agent_id", onboarding_agent_id) def _update_onboarding_agent( self, onboarding_agent_id: str, status: Optional[str] = None @@ -1461,9 +1422,7 @@ def _find_onboarding_agents( ) rows = c.fetchall() return [ - OnboardingAgent( - self, str(r["onboarding_agent_id"]), row=r, _used_new_call=True - ) + OnboardingAgent(self, str(r["onboarding_agent_id"]), row=r, _used_new_call=True) for r in rows ] diff --git a/mephisto/abstractions/providers/mock/mock_provider.py b/mephisto/abstractions/providers/mock/mock_provider.py index 27ea7f0ba..e3ad456b2 100644 --- a/mephisto/abstractions/providers/mock/mock_provider.py +++ b/mephisto/abstractions/providers/mock/mock_provider.py @@ -71,9 +71,7 @@ def setup_resources_for_task_run( """Mocks don't do any initialization""" return None - def cleanup_resources_from_task_run( - self, task_run: "TaskRun", server_url: str - ) -> None: + def cleanup_resources_from_task_run(self, task_run: "TaskRun", server_url: str) -> None: """Mocks don't do any initialization""" return None diff --git a/mephisto/abstractions/providers/mock/mock_requester.py b/mephisto/abstractions/providers/mock/mock_requester.py index 190a3d4a5..5065c8364 100644 --- a/mephisto/abstractions/providers/mock/mock_requester.py +++ b/mephisto/abstractions/providers/mock/mock_requester.py @@ -29,9 +29,7 @@ class MockRequesterArgs(RequesterArgs): "required": True, }, ) - force_fail: bool = field( - default=False, metadata={"help": "Trigger a failed registration"} - ) + force_fail: bool = field(default=False, metadata={"help": "Trigger a failed registration"}) class MockRequester(Requester): diff --git a/mephisto/abstractions/providers/mock/mock_unit.py b/mephisto/abstractions/providers/mock/mock_unit.py index de12d11a4..1153c74ad 100644 --- a/mephisto/abstractions/providers/mock/mock_unit.py +++ b/mephisto/abstractions/providers/mock/mock_unit.py @@ -48,15 +48,10 @@ def launch(self, task_url: str) -> None: # TODO(OWN) get this link to the frontend port = task_url.split(":")[1].split("/")[0] if port: - assignment_url = ( - f"http://localhost:{port}/?worker_id=x&assignment_id={self.db_id}" - ) + assignment_url = f"http://localhost:{port}/?worker_id=x&assignment_id={self.db_id}" else: assignment_url = f"{task_url}/?worker_id=x&assignment_id={self.db_id}" - print( - f"Mock task launched: http://localhost:{port} for preview, " - f"{assignment_url}" - ) + print(f"Mock task launched: http://localhost:{port} for preview, " f"{assignment_url}") logger.info( f"Mock task launched: http://localhost:{port} for preview, " f"{assignment_url} for assignment {self.assignment_id}" @@ -79,8 +74,6 @@ def is_expired(self) -> bool: return self.datastore.get_unit_expired(self.db_id) @staticmethod - def new( - db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float - ) -> "Unit": + def new(db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float) -> "Unit": """Create a Unit for the given assignment""" return MockUnit._register_unit(db, assignment, index, pay_amount, PROVIDER_TYPE) diff --git a/mephisto/abstractions/providers/mturk/mturk_agent.py b/mephisto/abstractions/providers/mturk/mturk_agent.py index 6ff3e2245..23596ad7a 100644 --- a/mephisto/abstractions/providers/mturk/mturk_agent.py +++ b/mephisto/abstractions/providers/mturk/mturk_agent.py @@ -50,9 +50,7 @@ def __init__( _used_new_call: bool = False, ): super().__init__(db, db_id, row=row, _used_new_call=_used_new_call) - self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider( - self.PROVIDER_TYPE - ) + self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider(self.PROVIDER_TYPE) unit: "MTurkUnit" = cast("MTurkUnit", self.get_unit()) self.mturk_assignment_id = unit.get_mturk_assignment_id() @@ -83,12 +81,8 @@ def new_from_provider_data( """ from mephisto.abstractions.providers.mturk.mturk_unit import MTurkUnit - assert isinstance( - unit, MTurkUnit - ), "Can only register mturk agents to mturk units" - unit.register_from_provider_data( - provider_data["hit_id"], provider_data["assignment_id"] - ) + assert isinstance(unit, MTurkUnit), "Can only register mturk agents to mturk units" + unit.register_from_provider_data(provider_data["hit_id"], provider_data["assignment_id"]) return super().new_from_provider_data(db, worker, unit, provider_data) def attempt_to_reconcile_submitted_data(self, mturk_hit_id: str): @@ -101,9 +95,7 @@ def attempt_to_reconcile_submitted_data(self, mturk_hit_id: str): assignment = get_assignments_for_hit(client, mturk_hit_id)[0] xml_data = xmltodict.parse(assignment["Answer"]) paired_data = json.loads(json.dumps(xml_data["QuestionFormAnswers"]["Answer"])) - parsed_data = { - entry["QuestionIdentifier"]: entry["FreeText"] for entry in paired_data - } + parsed_data = {entry["QuestionIdentifier"]: entry["FreeText"] for entry in paired_data} parsed_data["MEPHISTO_MTURK_RECONCILED"] = True self.handle_submit(parsed_data) diff --git a/mephisto/abstractions/providers/mturk/mturk_datastore.py b/mephisto/abstractions/providers/mturk/mturk_datastore.py index d874127af..bae6b252c 100644 --- a/mephisto/abstractions/providers/mturk/mturk_datastore.py +++ b/mephisto/abstractions/providers/mturk/mturk_datastore.py @@ -201,9 +201,7 @@ def register_assignment_to_hit( if len(results) > 0 and results[0]["unit_id"] is not None: old_unit_id = results[0]["unit_id"] self._mark_hit_mapping_update(old_unit_id) - logger.debug( - f"Cleared HIT mapping cache for previous unit, {old_unit_id}" - ) + logger.debug(f"Cleared HIT mapping cache for previous unit, {old_unit_id}") c.execute( """UPDATE hits @@ -338,9 +336,7 @@ def create_qualification_mapping( f"Multiple mturk mapping creations for qualification {qualification_name}. " f"Found existing one: {qual}. " ) - assert ( - qual is not None - ), "Cannot be none given is_unique_failure on insert" + assert qual is not None, "Cannot be none given is_unique_failure on insert" cur_requester_id = qual["requester_id"] cur_mturk_qualification_name = qual["mturk_qualification_name"] cur_mturk_qualification_id = qual["mturk_qualification_id"] @@ -358,9 +354,7 @@ def create_qualification_mapping( else: raise e - def get_qualification_mapping( - self, qualification_name: str - ) -> Optional[sqlite3.Row]: + def get_qualification_mapping(self, qualification_name: str) -> Optional[sqlite3.Row]: """Get the mapping between Mephisto qualifications and MTurk qualifications""" with self.table_access_condition: conn = self._get_connection() @@ -383,9 +377,7 @@ def get_session_for_requester(self, requester_name: str) -> boto3.Session: the existing one if it has already been created """ if requester_name not in self.session_storage: - session = boto3.Session( - profile_name=requester_name, region_name=MTURK_REGION_NAME - ) + session = boto3.Session(profile_name=requester_name, region_name=MTURK_REGION_NAME) self.session_storage[requester_name] = session return self.session_storage[requester_name] diff --git a/mephisto/abstractions/providers/mturk/mturk_provider.py b/mephisto/abstractions/providers/mturk/mturk_provider.py index c5449ef2e..a6e836c38 100644 --- a/mephisto/abstractions/providers/mturk/mturk_provider.py +++ b/mephisto/abstractions/providers/mturk/mturk_provider.py @@ -94,19 +94,16 @@ def setup_resources_for_task_run( qualifications = [] for qualification in shared_state.qualifications: applicable_providers = qualification["applicable_providers"] - if ( - applicable_providers is None - or self.PROVIDER_TYPE in applicable_providers - ): + if applicable_providers is None or self.PROVIDER_TYPE in applicable_providers: qualifications.append(qualification) for qualification in qualifications: qualification_name = qualification["qualification_name"] if requester.PROVIDER_TYPE == "mturk_sandbox": qualification_name += "_sandbox" if self.datastore.get_qualification_mapping(qualification_name) is None: - qualification[ - "QualificationTypeId" - ] = requester._create_new_mturk_qualification(qualification_name) + qualification["QualificationTypeId"] = requester._create_new_mturk_qualification( + qualification_name + ) if hasattr(shared_state, "mturk_specific_qualifications"): # TODO(OWN) standardize provider-specific qualifications @@ -115,14 +112,10 @@ def setup_resources_for_task_run( # Set up HIT type client = self._get_client(requester._requester_name) hit_type_id = create_hit_type(client, task_args, qualifications) - frame_height = ( - task_run.get_blueprint().get_frontend_args().get("frame_height", 0) - ) + frame_height = task_run.get_blueprint().get_frontend_args().get("frame_height", 0) self.datastore.register_run(task_run_id, hit_type_id, config_dir, frame_height) - def cleanup_resources_from_task_run( - self, task_run: "TaskRun", server_url: str - ) -> None: + def cleanup_resources_from_task_run(self, task_run: "TaskRun", server_url: str) -> None: """No cleanup necessary for task type""" pass diff --git a/mephisto/abstractions/providers/mturk/mturk_requester.py b/mephisto/abstractions/providers/mturk/mturk_requester.py index 699237974..8e4614226 100644 --- a/mephisto/abstractions/providers/mturk/mturk_requester.py +++ b/mephisto/abstractions/providers/mturk/mturk_requester.py @@ -74,9 +74,7 @@ def __init__( _used_new_call: bool = False, ): super().__init__(db, db_id, row=row, _used_new_call=_used_new_call) - self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider( - self.PROVIDER_TYPE - ) + self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider(self.PROVIDER_TYPE) # Use _requester_name to preserve sandbox behavior which # utilizes a different requester_name self._requester_name = self.requester_name @@ -97,9 +95,7 @@ def register(self, args: Optional[DictConfig] = None) -> None: """ for req_field in ["access_key_id", "secret_access_key"]: if args is not None and req_field not in args: - raise Exception( - f'Missing IAM "{req_field}" in requester registration args' - ) + raise Exception(f'Missing IAM "{req_field}" in requester registration args') setup_aws_credentials(self._requester_name, args) def is_registered(self) -> bool: diff --git a/mephisto/abstractions/providers/mturk/mturk_unit.py b/mephisto/abstractions/providers/mturk/mturk_unit.py index 94aeed797..b964bcf70 100644 --- a/mephisto/abstractions/providers/mturk/mturk_unit.py +++ b/mephisto/abstractions/providers/mturk/mturk_unit.py @@ -53,9 +53,7 @@ def __init__( _used_new_call: bool = False, ): super().__init__(db, db_id, row=row, _used_new_call=_used_new_call) - self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider( - self.PROVIDER_TYPE - ) + self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider(self.PROVIDER_TYPE) self.hit_id: Optional[str] = None self._last_sync_time = 0.0 self._sync_hit_mapping() @@ -86,13 +84,9 @@ def _sync_hit_mapping(self) -> None: # value the moment it's registered self._last_sync_time = time.monotonic() - 1 - def register_from_provider_data( - self, hit_id: str, mturk_assignment_id: str - ) -> None: + def register_from_provider_data(self, hit_id: str, mturk_assignment_id: str) -> None: """Update the datastore and local information from this registration""" - self.datastore.register_assignment_to_hit( - hit_id, self.db_id, mturk_assignment_id - ) + self.datastore.register_assignment_to_hit(hit_id, self.db_id, mturk_assignment_id) self._sync_hit_mapping() def get_mturk_assignment_id(self) -> Optional[str]: @@ -133,9 +127,7 @@ def set_db_status(self, status: str) -> None: ) try: hit_id = self.get_mturk_hit_id() - assert ( - hit_id is not None - ), f"This unit does not have an ID! {self}" + assert hit_id is not None, f"This unit does not have an ID! {self}" agent.attempt_to_reconcile_submitted_data(hit_id) except Exception as e: @@ -365,13 +357,11 @@ def is_expired(self) -> bool: return self.get_status() == AssignmentState.EXPIRED @staticmethod - def new( - db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float - ) -> "Unit": + def new(db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float) -> "Unit": """Create a Unit for the given assignment""" - return MTurkUnit._register_unit( - db, assignment, index, pay_amount, PROVIDER_TYPE - ) + return MTurkUnit._register_unit(db, assignment, index, pay_amount, PROVIDER_TYPE) def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.db_id}, {self.get_mturk_hit_id()}, {self.db_status})" + return ( + f"{self.__class__.__name__}({self.db_id}, {self.get_mturk_hit_id()}, {self.db_status})" + ) diff --git a/mephisto/abstractions/providers/mturk/mturk_utils.py b/mephisto/abstractions/providers/mturk/mturk_utils.py index e477b374d..d1367f2ba 100644 --- a/mephisto/abstractions/providers/mturk/mturk_utils.py +++ b/mephisto/abstractions/providers/mturk/mturk_utils.py @@ -38,9 +38,7 @@ botoconfig = Config(retries=dict(max_attempts=10)) -QUALIFICATION_TYPE_EXISTS_MESSAGE = ( - "You have already created a QualificationType with this name." -) +QUALIFICATION_TYPE_EXISTS_MESSAGE = "You have already created a QualificationType with this name." def client_is_sandbox(client: MTurkClient) -> bool: @@ -60,9 +58,7 @@ def check_aws_credentials(profile_name: str) -> bool: return False -def setup_aws_credentials( - profile_name: str, register_args: Optional[DictConfig] = None -) -> bool: +def setup_aws_credentials(profile_name: str, register_args: Optional[DictConfig] = None) -> bool: if not os.path.exists(os.path.expanduser("~/.aws/")): os.makedirs(os.path.expanduser("~/.aws/")) aws_credentials_file_path = "~/.aws/credentials" @@ -81,15 +77,11 @@ def setup_aws_credentials( # iterating to get the profile for credentialIndex in range(0, len(aws_credentials)): - if str(aws_credentials[credentialIndex]).startswith( - "[{}]".format(profile_name) - ): - aws_credentials[ - credentialIndex + 1 - ] = "aws_access_key_id={}".format(register_args.access_key_id) - aws_credentials[ - credentialIndex + 2 - ] = "aws_secret_access_key={}".format( + if str(aws_credentials[credentialIndex]).startswith("[{}]".format(profile_name)): + aws_credentials[credentialIndex + 1] = "aws_access_key_id={}".format( + register_args.access_key_id + ) + aws_credentials[credentialIndex + 2] = "aws_secret_access_key={}".format( register_args.secret_access_key ) break @@ -136,17 +128,9 @@ def setup_aws_credentials( aws_credentials_file.write("\n\n") # Write login details aws_credentials_file.write("[{}]\n".format(profile_name)) - aws_credentials_file.write( - "aws_access_key_id={}\n".format(aws_access_key_id) - ) - aws_credentials_file.write( - "aws_secret_access_key={}\n".format(aws_secret_access_key) - ) - print( - "AWS credentials successfully saved in {} file.\n".format( - aws_credentials_file_path - ) - ) + aws_credentials_file.write("aws_access_key_id={}\n".format(aws_access_key_id)) + aws_credentials_file.write("aws_secret_access_key={}\n".format(aws_secret_access_key)) + print("AWS credentials successfully saved in {} file.\n".format(aws_credentials_file_path)) return True @@ -360,9 +344,7 @@ def remove_worker_qualification( ) -def convert_mephisto_qualifications( - client: MTurkClient, qualifications: List[Dict[str, Any]] -): +def convert_mephisto_qualifications(client: MTurkClient, qualifications: List[Dict[str, Any]]): """Convert qualifications from mephisto's format to MTurk's""" converted_qualifications = [] for qualification in qualifications: @@ -540,9 +522,7 @@ def create_compensation_hit_with_hit_type( url_target = "workersandbox" if not is_sandbox: url_target = "www" - hit_link = "https://{}.mturk.com/mturk/preview?groupId={}".format( - url_target, hit_type_id - ) + hit_link = "https://{}.mturk.com/mturk/preview?groupId={}".format(url_target, hit_type_id) return hit_link, hit_id, response @@ -586,9 +566,7 @@ def create_hit_with_hit_type( url_target = "workersandbox" if not is_sandbox: url_target = "www" - hit_link = "https://{}.mturk.com/mturk/preview?groupId={}".format( - url_target, hit_type_id - ) + hit_link = "https://{}.mturk.com/mturk/preview?groupId={}".format(url_target, hit_type_id) return hit_link, hit_id, response @@ -604,9 +582,7 @@ def get_hit(client: MTurkClient, hit_id: str) -> Dict[str, Any]: try: return client.get_hit(HITId=hit_id) except ClientError as er: - logger.warning( - f"Skipping HIT {hit_id}. Unable to retrieve due to ClientError: {er}." - ) + logger.warning(f"Skipping HIT {hit_id}. Unable to retrieve due to ClientError: {er}.") return {} @@ -623,14 +599,10 @@ def get_assignments_for_hit(client: MTurkClient, hit_id: str) -> List[Dict[str, return assignments_info.get("Assignments", []) -def approve_work( - client: MTurkClient, assignment_id: str, override_rejection: bool = False -) -> None: +def approve_work(client: MTurkClient, assignment_id: str, override_rejection: bool = False) -> None: """approve work for a given assignment through the mturk client""" try: - client.approve_assignment( - AssignmentId=assignment_id, OverrideRejection=override_rejection - ) + client.approve_assignment(AssignmentId=assignment_id, OverrideRejection=override_rejection) except Exception as e: logger.exception( f"Approving MTurk assignment failed, likely because it has auto-approved. Details: {e}", @@ -649,18 +621,14 @@ def reject_work(client: MTurkClient, assignment_id: str, reason: str) -> None: ) -def approve_assignments_for_hit( - client: MTurkClient, hit_id: str, override_rejection: bool = False -): +def approve_assignments_for_hit(client: MTurkClient, hit_id: str, override_rejection: bool = False): """Approve work for assignments associated with a given hit, through mturk client """ assignments = get_assignments_for_hit(client, hit_id) for assignment in assignments: assignment_id = assignment["AssignmentId"] - client.approve_assignment( - AssignmentId=assignment_id, OverrideRejection=override_rejection - ) + client.approve_assignment(AssignmentId=assignment_id, OverrideRejection=override_rejection) def block_worker(client: MTurkClient, worker_id: str, reason: str) -> None: @@ -751,9 +719,7 @@ def expire_and_dispose_hits( try: client.delete_hit(HITId=h["HITId"]) except Exception as e: - client.update_expiration_for_hit( - HITId=h["HITId"], ExpireAt=datetime(2015, 1, 1) - ) + client.update_expiration_for_hit(HITId=h["HITId"], ExpireAt=datetime(2015, 1, 1)) h["dispose_exception"] = e non_disposed_hits.append(h) return non_disposed_hits @@ -764,9 +730,7 @@ def try_prerun_cleanup(db: "MephistoDB", requester_name: str) -> None: Try to see if there are any outstanding HITS for the given requester, and LOUDLY WARN if there are any, allowing the user to run a cleanup in-line. """ - cleanups_path = os.path.join( - DEFAULT_CONFIG_FOLDER, "mturk_requesters_last_cleanups.json" - ) + cleanups_path = os.path.join(DEFAULT_CONFIG_FOLDER, "mturk_requesters_last_cleanups.json") last_cleanup_times = {} if os.path.exists(cleanups_path): with open(cleanups_path) as cleanups_file: @@ -787,10 +751,7 @@ def try_prerun_cleanup(db: "MephistoDB", requester_name: str) -> None: client = requester._get_client(requester._requester_name) def hit_is_broken(hit: Dict[str, Any]) -> bool: - return ( - hit["NumberOfAssignmentsCompleted"] == 0 - and hit["HITStatus"] != "Reviewable" - ) + return hit["NumberOfAssignmentsCompleted"] == 0 and hit["HITStatus"] != "Reviewable" query_time = time.time() outstanding_hit_types = get_outstanding_hits(client) @@ -836,9 +797,7 @@ def hit_is_broken(hit: Dict[str, Any]) -> bool: print(f"HIT COUNT: {hit_count}") should_clear = "" while not (should_clear.startswith("y") or should_clear.startswith("n")): - should_clear = input( - "Should we cleanup this hit type? (y)es or (n)o: " "\n>> " - ).lower() + should_clear = input("Should we cleanup this hit type? (y)es or (n)o: " "\n>> ").lower() if should_clear.startswith("y"): hits_to_dispose += broken_hit_types[hit_type] confirm_string += ( diff --git a/mephisto/abstractions/providers/mturk/mturk_worker.py b/mephisto/abstractions/providers/mturk/mturk_worker.py index 8c8287bfa..22954c8e7 100644 --- a/mephisto/abstractions/providers/mturk/mturk_worker.py +++ b/mephisto/abstractions/providers/mturk/mturk_worker.py @@ -50,9 +50,7 @@ def __init__( _used_new_call: bool = False, ): super().__init__(db, db_id, row=row, _used_new_call=_used_new_call) - self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider( - self.PROVIDER_TYPE - ) + self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider(self.PROVIDER_TYPE) self._worker_name = self.worker_name # sandbox workers use a different name @classmethod @@ -62,13 +60,9 @@ def get_from_mturk_worker_id( """Get the MTurkWorker from the given worker_id""" if cls.PROVIDER_TYPE != PROVIDER_TYPE: mturk_worker_id += "_sandbox" - workers = db.find_workers( - worker_name=mturk_worker_id, provider_type=cls.PROVIDER_TYPE - ) + workers = db.find_workers(worker_name=mturk_worker_id, provider_type=cls.PROVIDER_TYPE) if len(workers) == 0: - logger.warning( - f"Could not find a Mephisto Worker for mturk_id {mturk_worker_id}" - ) + logger.warning(f"Could not find a Mephisto Worker for mturk_id {mturk_worker_id}") return None return cast("MTurkWorker", workers[0]) @@ -81,9 +75,7 @@ def _get_client(self, requester_name: str) -> Any: """ return self.datastore.get_client_for_requester(requester_name) - def grant_crowd_qualification( - self, qualification_name: str, value: int = 1 - ) -> None: + def grant_crowd_qualification(self, qualification_name: str, value: int = 1) -> None: """ Grant a qualification by the given name to this worker. Check the local MTurk db to find the matching MTurk qualification to grant, and pass @@ -93,30 +85,20 @@ def grant_crowd_qualification( requester to associate that qualification with by using the FIRST requester of the given account type (either `mturk` or `mturk_sandbox`) """ - mturk_qual_details = self.datastore.get_qualification_mapping( - qualification_name - ) + mturk_qual_details = self.datastore.get_qualification_mapping(qualification_name) if mturk_qual_details is not None: requester = Requester.get(self.db, mturk_qual_details["requester_id"]) qualification_id = mturk_qual_details["mturk_qualification_id"] else: - target_type = ( - "mturk_sandbox" if qualification_name.endswith("sandbox") else "mturk" - ) + target_type = "mturk_sandbox" if qualification_name.endswith("sandbox") else "mturk" requester = self.db.find_requesters(provider_type=target_type)[-1] assert isinstance( requester, MTurkRequester ), "find_requesters must return mturk requester for given provider types" - qualification_id = requester._create_new_mturk_qualification( - qualification_name - ) - assert isinstance( - requester, MTurkRequester - ), "Must be an MTurk requester for MTurk quals" + qualification_id = requester._create_new_mturk_qualification(qualification_name) + assert isinstance(requester, MTurkRequester), "Must be an MTurk requester for MTurk quals" client = self._get_client(requester._requester_name) - give_worker_qualification( - client, self.get_mturk_worker_id(), qualification_id, value - ) + give_worker_qualification(client, self.get_mturk_worker_id(), qualification_id, value) return None def revoke_crowd_qualification(self, qualification_name: str) -> None: @@ -125,9 +107,7 @@ def revoke_crowd_qualification(self, qualification_name: str) -> None: MTurk db to find the matching MTurk qualification to revoke, pass if no such qualification exists. """ - mturk_qual_details = self.datastore.get_qualification_mapping( - qualification_name - ) + mturk_qual_details = self.datastore.get_qualification_mapping(qualification_name) if mturk_qual_details is None: logger.error( f"No locally stored MTurk qualification to revoke for name {qualification_name}" @@ -135,14 +115,10 @@ def revoke_crowd_qualification(self, qualification_name: str) -> None: return None requester = Requester.get(self.db, mturk_qual_details["requester_id"]) - assert isinstance( - requester, MTurkRequester - ), "Must be an MTurk requester from MTurk quals" + assert isinstance(requester, MTurkRequester), "Must be an MTurk requester from MTurk quals" client = self._get_client(requester._requester_name) qualification_id = mturk_qual_details["mturk_qualification_id"] - remove_worker_qualification( - client, self.get_mturk_worker_id(), qualification_id - ) + remove_worker_qualification(client, self.get_mturk_worker_id(), qualification_id) return None def bonus_worker( @@ -155,15 +131,11 @@ def bonus_worker( return False, "bonusing via compensation tasks not yet available" unit = cast("MTurkUnit", unit) - requester = cast( - "MTurkRequester", unit.get_assignment().get_task_run().get_requester() - ) + requester = cast("MTurkRequester", unit.get_assignment().get_task_run().get_requester()) client = self._get_client(requester._requester_name) mturk_assignment_id = unit.get_mturk_assignment_id() assert mturk_assignment_id is not None, "Cannot bonus for a unit with no agent" - pay_bonus( - client, self._worker_name, amount, mturk_assignment_id, reason, str(uuid4()) - ) + pay_bonus(client, self._worker_name, amount, mturk_assignment_id, reason, str(uuid4())) return True, "" def block_worker( diff --git a/mephisto/abstractions/providers/mturk/utils/script_utils.py b/mephisto/abstractions/providers/mturk/utils/script_utils.py index 7836b2119..417db8ed0 100644 --- a/mephisto/abstractions/providers/mturk/utils/script_utils.py +++ b/mephisto/abstractions/providers/mturk/utils/script_utils.py @@ -51,9 +51,7 @@ def direct_assign_qual_mturk_workers( mturk_client = requester._get_client(requester._requester_name) for worker_id in tqdm(worker_list): try: - give_worker_qualification( - mturk_client, worker_id, qualification_id, value=1 - ) + give_worker_qualification(mturk_client, worker_id, qualification_id, value=1) except Exception as e: logging.exception( f'Failed to give worker with ID: "{worker_id}" qualification with error: {e}. Skipping.' diff --git a/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_agent.py b/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_agent.py index bdcd2e4c9..6fd3d624d 100644 --- a/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_agent.py +++ b/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_agent.py @@ -32,12 +32,8 @@ def _get_client(self) -> Any: Get an mturk client for usage with mturk_utils for this agent """ unit = self.get_unit() - requester: "SandboxMTurkRequester" = cast( - "SandboxMTurkRequester", unit.get_requester() - ) - return self.datastore.get_sandbox_client_for_requester( - requester._requester_name - ) + requester: "SandboxMTurkRequester" = cast("SandboxMTurkRequester", unit.get_requester()) + return self.datastore.get_sandbox_client_for_requester(requester._requester_name) @staticmethod def new(db: "MephistoDB", worker: "Worker", unit: "Unit") -> "Agent": diff --git a/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_requester.py b/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_requester.py index d1cd6580f..fce982cc0 100644 --- a/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_requester.py +++ b/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_requester.py @@ -29,9 +29,7 @@ def __init__( _used_new_call: bool = False, ): super().__init__(db, db_id, row=row, _used_new_call=_used_new_call) - self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider( - self.PROVIDER_TYPE - ) + self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider(self.PROVIDER_TYPE) # Use _requester_name to preserve sandbox behavior which # utilizes a different requester_name assert self.requester_name.endswith( @@ -58,6 +56,4 @@ def is_sandbox(cls) -> bool: def new(db: "MephistoDB", requester_name: str) -> "Requester": if not requester_name.endswith("_sandbox"): requester_name += "_sandbox" - return SandboxMTurkRequester._register_requester( - db, requester_name, PROVIDER_TYPE - ) + return SandboxMTurkRequester._register_requester(db, requester_name, PROVIDER_TYPE) diff --git a/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_unit.py b/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_unit.py index eb150fb6c..fc20dad11 100644 --- a/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_unit.py +++ b/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_unit.py @@ -34,10 +34,6 @@ def _get_client(self, requester_name: str) -> Any: return self.datastore.get_sandbox_client_for_requester(requester_name) @staticmethod - def new( - db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float - ) -> "Unit": + def new(db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float) -> "Unit": """Create a Unit for the given assignment""" - return SandboxMTurkUnit._register_unit( - db, assignment, index, pay_amount, PROVIDER_TYPE - ) + return SandboxMTurkUnit._register_unit(db, assignment, index, pay_amount, PROVIDER_TYPE) diff --git a/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_worker.py b/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_worker.py index d483eca08..e851c02f8 100644 --- a/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_worker.py +++ b/mephisto/abstractions/providers/mturk_sandbox/sandbox_mturk_worker.py @@ -31,15 +31,11 @@ def __init__( _used_new_call: bool = False, ): super().__init__(db, db_id, row=row, _used_new_call=_used_new_call) - self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider( - self.PROVIDER_TYPE - ) + self.datastore: "MTurkDatastore" = self.db.get_datastore_for_provider(self.PROVIDER_TYPE) # sandbox workers use a different name self._worker_name = self.worker_name[:-8] - def grant_crowd_qualification( - self, qualification_name: str, value: int = 1 - ) -> None: + def grant_crowd_qualification(self, qualification_name: str, value: int = 1) -> None: """ Grant a qualification by the given name to this worker. Check the local MTurk db to find the matching MTurk qualification to grant, and pass diff --git a/mephisto/abstractions/providers/prolific/api/base_api_resource.py b/mephisto/abstractions/providers/prolific/api/base_api_resource.py index af30f52df..85ec0bcde 100644 --- a/mephisto/abstractions/providers/prolific/api/base_api_resource.py +++ b/mephisto/abstractions/providers/prolific/api/base_api_resource.py @@ -70,9 +70,11 @@ def _base_request( url = urljoin(BASE_URL, api_endpoint) headers = headers or {} - headers.update({ - "Authorization": f"Token {api_key}", - }) + headers.update( + { + "Authorization": f"Token {api_key}", + } + ) logger.debug(f"{log_prefix} {method} {url}. Params: {params}") @@ -91,10 +93,7 @@ def _base_request( raise ProlificException("Invalid HTTP method.") response.raise_for_status() - if ( - response.status_code == status.HTTP_204_NO_CONTENT and - not response.content - ): + if response.status_code == status.HTTP_204_NO_CONTENT and not response.content: result = None else: result = response.json() @@ -104,9 +103,7 @@ def _base_request( return result except requests.exceptions.HTTPError as err: - logger.error( - f"{log_prefix} Request error: {err}. Response text: `{err.response.text}`" - ) + logger.error(f"{log_prefix} Request error: {err}. Response text: `{err.response.text}`") if err.response.status_code == status.HTTP_401_UNAUTHORIZED: raise ProlificAuthenticationError diff --git a/mephisto/abstractions/providers/prolific/api/bonuses.py b/mephisto/abstractions/providers/prolific/api/bonuses.py index d76f49514..635c166bd 100644 --- a/mephisto/abstractions/providers/prolific/api/bonuses.py +++ b/mephisto/abstractions/providers/prolific/api/bonuses.py @@ -10,8 +10,8 @@ class Bonuses(BaseAPIResource): - set_up_api_endpoint = 'submissions/bonus-payments/' - pay_api_endpoint = 'bulk-bonus-payments/{id}/pay/' + set_up_api_endpoint = "submissions/bonus-payments/" + pay_api_endpoint = "bulk-bonus-payments/{id}/pay/" @classmethod def set_up(cls, study_id: str, csv_bonuses: str) -> BonusPayments: diff --git a/mephisto/abstractions/providers/prolific/api/constants.py b/mephisto/abstractions/providers/prolific/api/constants.py index 96fed3714..1adebbcae 100644 --- a/mephisto/abstractions/providers/prolific/api/constants.py +++ b/mephisto/abstractions/providers/prolific/api/constants.py @@ -37,6 +37,7 @@ class StudyStatus: Study statuses explained https://docs.prolific.co/docs/api-docs/public/#tag/Studies/The-study-object """ + UNPUBLISHED = "UNPUBLISHED" ACTIVE = "ACTIVE" SCHEDULED = "SCHEDULED" diff --git a/mephisto/abstractions/providers/prolific/api/data_models/base_model.py b/mephisto/abstractions/providers/prolific/api/data_models/base_model.py index 3b13dd02b..d4c64306d 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/base_model.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/base_model.py @@ -11,14 +11,14 @@ class BaseModel: id: str schema = { - 'type': 'object', - 'properties': { - 'id': {'type': 'string'}, + "type": "object", + "properties": { + "id": {"type": "string"}, }, } required_schema_fields = [] - id_field_name = 'id' + id_field_name = "id" def __init__(self, **data): self.__dict__ = data @@ -27,13 +27,13 @@ def __str__(self) -> str: return f'{self.__class__.__name__} {getattr(self, "id", "")}' def __repr__(self) -> str: - return f'<{self.__str__()}>' + return f"<{self.__str__()}>" def validate(self, check_required_fields: bool = True): schema = dict(self.schema) if check_required_fields: - schema['required'] = self.required_schema_fields + schema["required"] = self.required_schema_fields validate(instance=self.__dict__, schema=schema) diff --git a/mephisto/abstractions/providers/prolific/api/data_models/bonus_payments.py b/mephisto/abstractions/providers/prolific/api/data_models/bonus_payments.py index 525b81ae9..5b635926b 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/bonus_payments.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/bonus_payments.py @@ -13,6 +13,7 @@ class BonusPayments(BaseModel): More about Bonuses: https://docs.prolific.co/docs/api-docs/public/#tag/Bonuses """ + amount: Union[int, float] fees: Union[int, float] id: str @@ -21,11 +22,11 @@ class BonusPayments(BaseModel): vat: Union[int, float] schema = { - 'type': 'object', - 'properties': { - 'id': {'type': 'string'}, + "type": "object", + "properties": { + "id": {"type": "string"}, }, } def __str__(self) -> str: - return f'{self.__class__.__name__} {self.id}' + return f"{self.__class__.__name__} {self.id}" diff --git a/mephisto/abstractions/providers/prolific/api/data_models/eligibility_requirement.py b/mephisto/abstractions/providers/prolific/api/data_models/eligibility_requirement.py index 37c48ffba..416ccdd85 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/eligibility_requirement.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/eligibility_requirement.py @@ -15,6 +15,7 @@ class EligibilityRequirement(BaseModel): More about Eligibility Requirements: https://docs.prolific.co/docs/api-docs/public/#tag/Requirements/Requirements-object """ + _cls: str attributes: List[dict] category: str @@ -28,49 +29,49 @@ class EligibilityRequirement(BaseModel): type: str schema = { - 'type': 'object', - 'properties': { - '_cls': {'type': 'string'}, - 'attributes': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'index': {'type': 'number'}, - 'label': {'type': 'string'}, - 'name': {'type': 'string'}, - 'value': {'type': 'boolean'}, + "type": "object", + "properties": { + "_cls": {"type": "string"}, + "attributes": { + "type": "array", + "items": { + "type": "object", + "properties": { + "index": {"type": "number"}, + "label": {"type": "string"}, + "name": {"type": "string"}, + "value": {"type": "boolean"}, }, }, }, - 'category': {'type': 'string'}, - 'details_display': {'type': 'string'}, - 'id': {'type': 'string'}, - 'order': {'type': 'number'}, - 'query': { - 'type': 'object', - 'properties': { - 'description': {'type': 'string'}, - 'help_text': {'type': 'string'}, - 'id': {'type': 'string'}, - 'is_new': {'type': 'boolean'}, - 'participant_help_text': {'type': 'string'}, - 'question': {'type': 'string'}, - 'researcher_help_text': {'type': 'string'}, - 'tags': {'type': 'array'}, - 'title': {'type': 'string'}, + "category": {"type": "string"}, + "details_display": {"type": "string"}, + "id": {"type": "string"}, + "order": {"type": "number"}, + "query": { + "type": "object", + "properties": { + "description": {"type": "string"}, + "help_text": {"type": "string"}, + "id": {"type": "string"}, + "is_new": {"type": "boolean"}, + "participant_help_text": {"type": "string"}, + "question": {"type": "string"}, + "researcher_help_text": {"type": "string"}, + "tags": {"type": "array"}, + "title": {"type": "string"}, }, }, - 'recommended': {'type': 'boolean'}, - 'requirement_type': {'type': 'string'}, - 'subcategory': {'type': ['string', 'null']}, - 'type': {'type': 'string'}, + "recommended": {"type": "boolean"}, + "requirement_type": {"type": "string"}, + "subcategory": {"type": ["string", "null"]}, + "type": {"type": "string"}, }, } def __init__(self, **data): super().__init__(**data) - setattr(self, 'id', data.get('query', {}).get('id')) + setattr(self, "id", data.get("query", {}).get("id")) def __str__(self) -> str: - return f'{self.__class__.__name__} {self._cls} {self.id}' + return f"{self.__class__.__name__} {self._cls} {self.id}" diff --git a/mephisto/abstractions/providers/prolific/api/data_models/message.py b/mephisto/abstractions/providers/prolific/api/data_models/message.py index 3787eb44b..fe966352e 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/message.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/message.py @@ -14,6 +14,7 @@ class Message(BaseModel): More about Messages: https://docs.prolific.co/docs/api-docs/public/#tag/Messages """ + body: str channel_id: Optional[str] recipient_id: str @@ -22,22 +23,22 @@ class Message(BaseModel): type: str schema = { - 'type' : 'object', - 'properties' : { - 'body': {'type' : 'string'}, - 'channel_id': {'type' : ['string', 'null']}, - 'recipient_id': {'type' : 'string'}, - 'sender_id': {'type' : 'string'}, - 'sent_at': {'type' : 'string'}, - 'type': {'type' : 'string'}, + "type": "object", + "properties": { + "body": {"type": "string"}, + "channel_id": {"type": ["string", "null"]}, + "recipient_id": {"type": "string"}, + "sender_id": {"type": "string"}, + "sent_at": {"type": "string"}, + "type": {"type": "string"}, }, } required_schema_fields = [ - 'body', - 'recipient_id', - 'study_id', + "body", + "recipient_id", + "study_id", ] def __str__(self) -> str: - return f'{self.__class__.__name__} {self.sender_id}: {self.body}' + return f"{self.__class__.__name__} {self.sender_id}: {self.body}" diff --git a/mephisto/abstractions/providers/prolific/api/data_models/participant.py b/mephisto/abstractions/providers/prolific/api/data_models/participant.py index b0283cc49..1256835e8 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/participant.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/participant.py @@ -12,14 +12,14 @@ class Participant(BaseModel): datetime_created: str schema = { - 'type': 'object', - 'properties': { - 'participant_id': {'type': 'string'}, - 'datetime_created': {'type': 'string'}, + "type": "object", + "properties": { + "participant_id": {"type": "string"}, + "datetime_created": {"type": "string"}, }, } - id_field_name = 'participant_id' + id_field_name = "participant_id" def __str__(self) -> str: - return f'{self.__class__.__name__} {self.participant_id}' + return f"{self.__class__.__name__} {self.participant_id}" diff --git a/mephisto/abstractions/providers/prolific/api/data_models/participant_group.py b/mephisto/abstractions/providers/prolific/api/data_models/participant_group.py index 1beb41983..13d839776 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/participant_group.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/participant_group.py @@ -15,6 +15,7 @@ class ParticipantGroup(BaseModel): More about Participant Groups: https://docs.prolific.co/docs/api-docs/public/#tag/Participant-Groups """ + id: str name: str project_id: str @@ -22,22 +23,22 @@ class ParticipantGroup(BaseModel): feeder_studies: List[Dict] schema = { - 'type': 'object', - 'properties': { - 'id': {'type': 'string'}, - 'project_id': {'type': 'string'}, - 'name': {'type': 'string'}, - 'participant_count': {'type': 'number'}, - 'feeder_studies': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'id': {'type': 'string'}, - 'name': {'type': 'string'}, - 'internal_name': {'type': 'string'}, - 'status': {'type': 'string'}, - 'completion_codes': {'type': 'array'}, + "type": "object", + "properties": { + "id": {"type": "string"}, + "project_id": {"type": "string"}, + "name": {"type": "string"}, + "participant_count": {"type": "number"}, + "feeder_studies": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "name": {"type": "string"}, + "internal_name": {"type": "string"}, + "status": {"type": "string"}, + "completion_codes": {"type": "array"}, }, }, }, @@ -45,9 +46,9 @@ class ParticipantGroup(BaseModel): } required_schema_fields = [ - 'project_id', - 'name', + "project_id", + "name", ] def __str__(self) -> str: - return f'{self.__class__.__name__} {self.id} {self.name}' + return f"{self.__class__.__name__} {self.id} {self.name}" diff --git a/mephisto/abstractions/providers/prolific/api/data_models/project.py b/mephisto/abstractions/providers/prolific/api/data_models/project.py index 98222c3b6..a395a25a4 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/project.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/project.py @@ -19,6 +19,7 @@ class Project(BaseModel): More about Projects: https://docs.prolific.co/docs/api-docs/public/#tag/Projects """ + id: str title: str description: str @@ -29,34 +30,34 @@ class Project(BaseModel): naivety_distribution_rate: Optional[Union[Decimal, float]] schema = { - 'type': 'object', - 'properties': { - 'id': {'type': 'string'}, - 'title': {'type': 'string'}, - 'description': {'type': 'string'}, - 'owner': {'type': 'string'}, - 'users': { - 'type': 'array', - 'items': User.relation_user_schema, + "type": "object", + "properties": { + "id": {"type": "string"}, + "title": {"type": "string"}, + "description": {"type": "string"}, + "owner": {"type": "string"}, + "users": { + "type": "array", + "items": User.relation_user_schema, }, - 'studies': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'id': {'type': 'string'}, - 'name': {'type': 'string'}, + "studies": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "name": {"type": "string"}, }, - } + }, }, - 'workspace': {'type': 'string'}, - 'naivety_distribution_rate': {'type': ['number', 'null']}, + "workspace": {"type": "string"}, + "naivety_distribution_rate": {"type": ["number", "null"]}, }, } required_schema_fields = [ - 'title', + "title", ] def __str__(self) -> str: - return f'{self.__class__.__name__} {self.id} {self.title}' + return f"{self.__class__.__name__} {self.id} {self.title}" diff --git a/mephisto/abstractions/providers/prolific/api/data_models/study.py b/mephisto/abstractions/providers/prolific/api/data_models/study.py index 834b76759..e5f2e66a2 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/study.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/study.py @@ -16,6 +16,7 @@ class Study(BaseModel): More about Studies: https://docs.prolific.co/docs/api-docs/public/#tag/Studies """ + alternative_completion_codes: List average_reward_per_hour: Union[int, float] average_reward_per_hour_without_adjustment: Union[int, float] @@ -83,111 +84,111 @@ class Study(BaseModel): workspace: str schema = { - 'type': 'object', - 'properties': { - 'alternative_completion_codes': {'type': 'array'}, - 'average_reward_per_hour': {'type': 'number'}, - 'average_reward_per_hour_without_adjustment': {'type': 'number'}, - 'average_time_taken': {'type': 'number'}, - 'average_time_taken_seconds': {'type': 'number'}, - 'can_be_reallocated': {'type': 'boolean'}, - 'completion_code': {'type': 'string'}, - 'completion_code_action': {'type': ['string', 'null']}, - 'completion_codes': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'code': {'type': 'string'}, - 'code_type': {'type': 'string'}, - 'actions': {'type': 'array'}, + "type": "object", + "properties": { + "alternative_completion_codes": {"type": "array"}, + "average_reward_per_hour": {"type": "number"}, + "average_reward_per_hour_without_adjustment": {"type": "number"}, + "average_time_taken": {"type": "number"}, + "average_time_taken_seconds": {"type": "number"}, + "can_be_reallocated": {"type": "boolean"}, + "completion_code": {"type": "string"}, + "completion_code_action": {"type": ["string", "null"]}, + "completion_codes": { + "type": "array", + "items": { + "type": "object", + "properties": { + "code": {"type": "string"}, + "code_type": {"type": "string"}, + "actions": {"type": "array"}, }, - 'required': [ - 'code', - 'code_type', + "required": [ + "code", + "code_type", ], }, }, - 'completion_option': { - 'type': ['string', 'null'], - 'items': {'enum': ['url', 'code']}, + "completion_option": { + "type": ["string", "null"], + "items": {"enum": ["url", "code"]}, }, - 'currency_code': {'type': ['string', 'null']}, - 'date_created': {'type': 'string'}, - 'description': {'type': 'string'}, - 'device_compatibility': { - 'type': 'array', - 'items': {'enum': ['desktop', 'tablet', 'mobile']}, + "currency_code": {"type": ["string", "null"]}, + "date_created": {"type": "string"}, + "description": {"type": "string"}, + "device_compatibility": { + "type": "array", + "items": {"enum": ["desktop", "tablet", "mobile"]}, }, - 'discount_from_coupons': {'type': 'number'}, - 'eligibility_requirements': {'type': 'array'}, - 'eligible_participant_count': {'type': 'number'}, - 'estimated_completion_time': {'type': 'number'}, - 'estimated_reward_per_hour': {'type': 'number'}, - 'external_app': {'type': 'string'}, - 'external_id': {'type': 'string'}, - 'external_study_url': {'type': 'string'}, - 'failed_attention_code': {'type': ['string', 'null']}, - 'fees_per_submission': {'type': 'number'}, - 'fees_percentage': {'type': 'number'}, - 'has_had_adjustment': {'type': 'boolean'}, - 'id': {'type': 'string'}, - 'internal_name': {'type': 'string'}, - 'is_reallocated': {'type': 'boolean'}, - 'is_underpaying': {'type': ['boolean', 'null']}, - 'last_email_update_sent_datetime': {'type': ['string', 'null']}, - 'maximum_allowed_time': {'type': 'number'}, - 'metadata': {'type': ['string', 'object', 'number', 'null']}, - 'minimum_reward_per_hour': {'type': 'number'}, - 'naivety_distribution_rate': {'type': ['number', 'null']}, - 'name': {'type': 'string'}, - 'number_of_submissions': {'type': 'number'}, - 'peripheral_requirements': {'type': 'array'}, - 'places_taken': {'type': 'number'}, - 'privacy_notice': {'type': 'string'}, - 'progress_percentage': {'type': 'number'}, - 'project': {'type': 'string'}, - 'prolific_id_option': { - 'type': ['string', 'null'], - 'items': {'enum': ['question', 'url_parameters', 'not_required']}, + "discount_from_coupons": {"type": "number"}, + "eligibility_requirements": {"type": "array"}, + "eligible_participant_count": {"type": "number"}, + "estimated_completion_time": {"type": "number"}, + "estimated_reward_per_hour": {"type": "number"}, + "external_app": {"type": "string"}, + "external_id": {"type": "string"}, + "external_study_url": {"type": "string"}, + "failed_attention_code": {"type": ["string", "null"]}, + "fees_per_submission": {"type": "number"}, + "fees_percentage": {"type": "number"}, + "has_had_adjustment": {"type": "boolean"}, + "id": {"type": "string"}, + "internal_name": {"type": "string"}, + "is_reallocated": {"type": "boolean"}, + "is_underpaying": {"type": ["boolean", "null"]}, + "last_email_update_sent_datetime": {"type": ["string", "null"]}, + "maximum_allowed_time": {"type": "number"}, + "metadata": {"type": ["string", "object", "number", "null"]}, + "minimum_reward_per_hour": {"type": "number"}, + "naivety_distribution_rate": {"type": ["number", "null"]}, + "name": {"type": "string"}, + "number_of_submissions": {"type": "number"}, + "peripheral_requirements": {"type": "array"}, + "places_taken": {"type": "number"}, + "privacy_notice": {"type": "string"}, + "progress_percentage": {"type": "number"}, + "project": {"type": "string"}, + "prolific_id_option": { + "type": ["string", "null"], + "items": {"enum": ["question", "url_parameters", "not_required"]}, }, - 'publish_at': {'type': ['string', 'null']}, - 'published_at': {'type': ['string', 'null']}, - 'publisher': {'type': ['string', 'null']}, - 'quota_requirements': {'type': ['array', 'null']}, - 'reallocated_places': {'type': 'number'}, - 'receipt': {'type': ['string', 'null']}, - 'representative_sample': {'type': ['string', 'null']}, - 'representative_sample_fee': {'type': 'number'}, - 'researcher': {'type': 'object'}, - 'reward': {'type': 'number'}, - 'reward_level': {'type': 'object'}, - 'service_margin_percentage': {'type': 'string'}, - 'share_id': {'type': ['string', 'null']}, - 'stars_remaining': {'type': 'number'}, - 'status': {'type': 'string'}, - 'study_type': {'type': 'string'}, - 'submissions_config': {'type': 'object'}, - 'total_available_places': {'type': 'number'}, - 'total_cost': {'type': 'number'}, - 'total_participant_pool': {'type': 'number'}, - 'vat_percentage': {'type': 'number'}, - 'workspace': {'type': 'string'}, + "publish_at": {"type": ["string", "null"]}, + "published_at": {"type": ["string", "null"]}, + "publisher": {"type": ["string", "null"]}, + "quota_requirements": {"type": ["array", "null"]}, + "reallocated_places": {"type": "number"}, + "receipt": {"type": ["string", "null"]}, + "representative_sample": {"type": ["string", "null"]}, + "representative_sample_fee": {"type": "number"}, + "researcher": {"type": "object"}, + "reward": {"type": "number"}, + "reward_level": {"type": "object"}, + "service_margin_percentage": {"type": "string"}, + "share_id": {"type": ["string", "null"]}, + "stars_remaining": {"type": "number"}, + "status": {"type": "string"}, + "study_type": {"type": "string"}, + "submissions_config": {"type": "object"}, + "total_available_places": {"type": "number"}, + "total_cost": {"type": "number"}, + "total_participant_pool": {"type": "number"}, + "vat_percentage": {"type": "number"}, + "workspace": {"type": "string"}, }, } required_schema_fields = [ - 'name', - 'description', - 'external_study_url', - 'prolific_id_option', - 'completion_option', - 'completion_codes', - 'total_available_places', - 'estimated_completion_time', - 'reward', - 'eligibility_requirements', + "name", + "description", + "external_study_url", + "prolific_id_option", + "completion_option", + "completion_codes", + "total_available_places", + "estimated_completion_time", + "reward", + "eligibility_requirements", ] def __str__(self) -> str: - return f'{self.__class__.__name__} {self.id} {self.name}' + return f"{self.__class__.__name__} {self.id} {self.name}" diff --git a/mephisto/abstractions/providers/prolific/api/data_models/submission.py b/mephisto/abstractions/providers/prolific/api/data_models/submission.py index 89cc05d8d..721278543 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/submission.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/submission.py @@ -14,6 +14,7 @@ class Submission(BaseModel): More about Submissions: https://docs.prolific.co/docs/api-docs/public/#tag/Submissions """ + completed_at: str entered_code: str id: str @@ -23,20 +24,20 @@ class Submission(BaseModel): study_id: str schema = { - 'type': 'object', - 'properties': { - 'completed_at': {'type': 'string'}, - 'entered_code': {'type': 'string'}, - 'id': {'type': 'string'}, - 'participant': {'type': 'string'}, - 'started_at': {'type': 'string'}, - 'status': {'type': 'string'}, - 'study_id': {'type': 'string'}, + "type": "object", + "properties": { + "completed_at": {"type": "string"}, + "entered_code": {"type": "string"}, + "id": {"type": "string"}, + "participant": {"type": "string"}, + "started_at": {"type": "string"}, + "status": {"type": "string"}, + "study_id": {"type": "string"}, }, } def __str__(self) -> str: - return f'{self.__class__.__name__} {self.id}' + return f"{self.__class__.__name__} {self.id}" class ListSubmission(BaseModel): @@ -46,6 +47,7 @@ class ListSubmission(BaseModel): More about Submissions: https://docs.prolific.co/docs/api-docs/public/#tag/Submissions """ + completed_at: str has_siblings: bool entered_code: str @@ -62,24 +64,24 @@ class ListSubmission(BaseModel): time_taken: int schema = { - 'type': 'object', - 'properties': { - 'completed_at': {'type': 'string'}, - 'has_siblings': {'type': 'boolean'}, - 'entered_code': {'type': 'string'}, - 'id': {'type': 'string'}, - 'ip': {'type': 'string'}, - 'is_complete': {'type': 'boolean'}, - 'participant_id': {'type': 'string'}, - 'return_requested': {'type': ['string', 'null']}, - 'reward': {'type': 'number'}, - 'started_at': {'type': 'string'}, - 'status': {'type': 'string'}, - 'strata': {'type': 'object'}, - 'study_code': {'type': 'string'}, - 'time_taken': {'type': 'number'}, + "type": "object", + "properties": { + "completed_at": {"type": "string"}, + "has_siblings": {"type": "boolean"}, + "entered_code": {"type": "string"}, + "id": {"type": "string"}, + "ip": {"type": "string"}, + "is_complete": {"type": "boolean"}, + "participant_id": {"type": "string"}, + "return_requested": {"type": ["string", "null"]}, + "reward": {"type": "number"}, + "started_at": {"type": "string"}, + "status": {"type": "string"}, + "strata": {"type": "object"}, + "study_code": {"type": "string"}, + "time_taken": {"type": "number"}, }, } def __str__(self) -> str: - return f'{self.__class__.__name__} {self.id}' + return f"{self.__class__.__name__} {self.id}" diff --git a/mephisto/abstractions/providers/prolific/api/data_models/user.py b/mephisto/abstractions/providers/prolific/api/data_models/user.py index a675d700b..bcb91771e 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/user.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/user.py @@ -17,6 +17,7 @@ class User(BaseModel): More about Users: https://docs.prolific.co/docs/api-docs/public/#tag/Users """ + address: Optional[str] available_balance: Optional[Union[int, float]] balance: Optional[Union[int, float]] @@ -70,80 +71,80 @@ class User(BaseModel): vat_percentage: Optional[Union[int, float]] schema = { - 'type': 'object', - 'properties': { - 'address': {'type': ['string', 'null']}, - 'available_balance': {'type': ['number', 'null']}, - 'balance': {'type': ['number', 'null']}, - 'balance_breakdown': {'type': ['object', 'null']}, - 'beta_tester': {'type': 'boolean'}, - 'billing_address': {'type': ['string', 'null']}, - 'can_cashout_enabled': {'type': 'boolean'}, - 'can_contact_support_enabled': {'type': 'boolean'}, - 'can_instant_cashout_enabled': {'type': 'boolean'}, - 'can_oidc_login': {'type': 'boolean'}, - 'can_oidc_login_enabled': {'type': 'boolean'}, - 'can_run_pilot_study_enabled': {'type': 'boolean'}, - 'can_topup_3d': {'type': 'boolean'}, - 'country': {'type': ['string', 'null']}, - 'currency_code': {'type': 'string'}, - 'current_project_id': {'type': ['string', 'null']}, - 'date_joined': {'type': 'string'}, - 'datetime_created': {'type': 'string'}, - 'email': { - 'type': 'string', - 'pattern': EMAIL_FORMAT, + "type": "object", + "properties": { + "address": {"type": ["string", "null"]}, + "available_balance": {"type": ["number", "null"]}, + "balance": {"type": ["number", "null"]}, + "balance_breakdown": {"type": ["object", "null"]}, + "beta_tester": {"type": "boolean"}, + "billing_address": {"type": ["string", "null"]}, + "can_cashout_enabled": {"type": "boolean"}, + "can_contact_support_enabled": {"type": "boolean"}, + "can_instant_cashout_enabled": {"type": "boolean"}, + "can_oidc_login": {"type": "boolean"}, + "can_oidc_login_enabled": {"type": "boolean"}, + "can_run_pilot_study_enabled": {"type": "boolean"}, + "can_topup_3d": {"type": "boolean"}, + "country": {"type": ["string", "null"]}, + "currency_code": {"type": "string"}, + "current_project_id": {"type": ["string", "null"]}, + "date_joined": {"type": "string"}, + "datetime_created": {"type": "string"}, + "email": { + "type": "string", + "pattern": EMAIL_FORMAT, }, - 'email_preferences': {'type': 'object'}, - 'experimental_group': {'type': ['number', 'null']}, - 'fees_per_submission': {'type': ['number', 'null']}, - 'fees_percentage': {'type': ['number', 'null']}, - 'first_name': {'type': 'string'}, - 'has_accepted_survey_builder_terms': {'type': 'boolean'}, - 'has_answered_vat_number': {'type': 'boolean'}, - 'has_password': {'type': 'boolean'}, - 'id': {'type': 'string'}, - 'invoice_usage_enabled': {'type': 'boolean'}, - 'is_email_verified': {'type': 'boolean'}, - 'is_staff': {'type': 'boolean'}, - 'last_login': {'type': ['string', 'null']}, - 'last_name': {'type': 'string'}, - 'minimum_reward_per_hour': {'type': ['number', 'null']}, - 'name': {'type': 'string'}, - 'needs_to_confirm_US_state': {'type': 'boolean'}, - 'on_hold': {'type': 'boolean'}, - 'privacy_policy': {'type': 'boolean'}, - 'redeemable_referral_coupon': {'type': ['string', 'null']}, - 'referral_incentive': {'type': 'object'}, - 'referral_url': {'type': ['string', 'null']}, - 'representative_sample_credits': {'type': ['number', 'null']}, - 'roles': {'type': 'array'}, - 'service_margin_percentage': {'type': ['number', 'null']}, - 'status': {'type': 'string'}, - 'terms_and_conditions': {'type': 'boolean'}, - 'topups_over_referral_threshold': {'type': 'boolean'}, - 'user_type': {'type': 'string'}, - 'username': {'type': 'string'}, - 'vat_number': {'type': ['number', 'null']}, - 'vat_percentage': {'type': ['number', 'null']}, + "email_preferences": {"type": "object"}, + "experimental_group": {"type": ["number", "null"]}, + "fees_per_submission": {"type": ["number", "null"]}, + "fees_percentage": {"type": ["number", "null"]}, + "first_name": {"type": "string"}, + "has_accepted_survey_builder_terms": {"type": "boolean"}, + "has_answered_vat_number": {"type": "boolean"}, + "has_password": {"type": "boolean"}, + "id": {"type": "string"}, + "invoice_usage_enabled": {"type": "boolean"}, + "is_email_verified": {"type": "boolean"}, + "is_staff": {"type": "boolean"}, + "last_login": {"type": ["string", "null"]}, + "last_name": {"type": "string"}, + "minimum_reward_per_hour": {"type": ["number", "null"]}, + "name": {"type": "string"}, + "needs_to_confirm_US_state": {"type": "boolean"}, + "on_hold": {"type": "boolean"}, + "privacy_policy": {"type": "boolean"}, + "redeemable_referral_coupon": {"type": ["string", "null"]}, + "referral_incentive": {"type": "object"}, + "referral_url": {"type": ["string", "null"]}, + "representative_sample_credits": {"type": ["number", "null"]}, + "roles": {"type": "array"}, + "service_margin_percentage": {"type": ["number", "null"]}, + "status": {"type": "string"}, + "terms_and_conditions": {"type": "boolean"}, + "topups_over_referral_threshold": {"type": "boolean"}, + "user_type": {"type": "string"}, + "username": {"type": "string"}, + "vat_number": {"type": ["number", "null"]}, + "vat_percentage": {"type": ["number", "null"]}, }, } relation_user_schema = { - 'type': 'object', - 'properties': { - 'id': {'type': 'string'}, - 'name': {'type': 'string'}, - 'email': { - 'type': 'string', - 'pattern': EMAIL_FORMAT, + "type": "object", + "properties": { + "id": {"type": "string"}, + "name": {"type": "string"}, + "email": { + "type": "string", + "pattern": EMAIL_FORMAT, }, - 'roles': {'type': 'array'}, + "roles": {"type": "array"}, }, - 'required': [ - 'id', - ] + "required": [ + "id", + ], } def __str__(self) -> str: - return f'{self.__class__.__name__} {self.id} {self.email}' + return f"{self.__class__.__name__} {self.id} {self.email}" diff --git a/mephisto/abstractions/providers/prolific/api/data_models/workspace.py b/mephisto/abstractions/providers/prolific/api/data_models/workspace.py index 24ca8062a..5e7f949b7 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/workspace.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/workspace.py @@ -19,6 +19,7 @@ class Workspace(BaseModel): More about Workspaces: https://docs.prolific.co/docs/api-docs/public/#tag/Workspaces """ + description: str id: str naivety_distribution_rate: Optional[Union[Decimal, float]] @@ -29,41 +30,41 @@ class Workspace(BaseModel): wallet: str schema = { - 'type': 'object', - 'properties': { - 'id': {'type': 'string'}, - 'title': {'type': 'string'}, - 'description': {'type': 'string'}, - 'owner': {'type': 'string'}, - 'users': { - 'type': 'array', - 'items': User.relation_user_schema, + "type": "object", + "properties": { + "id": {"type": "string"}, + "title": {"type": "string"}, + "description": {"type": "string"}, + "owner": {"type": "string"}, + "users": { + "type": "array", + "items": User.relation_user_schema, }, - 'projects': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'id': {'type': 'string'}, - 'title': {'type': 'string'}, - 'description': {'type': 'string'}, - 'owner': {'type': 'string'}, - 'users': { - 'type': 'array', - 'items': User.relation_user_schema, + "projects": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "title": {"type": "string"}, + "description": {"type": "string"}, + "owner": {"type": "string"}, + "users": { + "type": "array", + "items": User.relation_user_schema, }, - 'naivety_distribution_rate': {'type': ['number', 'null']}, + "naivety_distribution_rate": {"type": ["number", "null"]}, }, }, }, - 'wallet': {'type': 'string'}, - 'naivety_distribution_rate': {'type': ['number', 'null']}, + "wallet": {"type": "string"}, + "naivety_distribution_rate": {"type": ["number", "null"]}, }, } required_schema_fields = [ - 'title', + "title", ] def __str__(self) -> str: - return f'{self.__class__.__name__} {self.id} {self.title}' + return f"{self.__class__.__name__} {self.id} {self.title}" diff --git a/mephisto/abstractions/providers/prolific/api/data_models/workspace_balance.py b/mephisto/abstractions/providers/prolific/api/data_models/workspace_balance.py index c9fe0b90e..a31949b81 100644 --- a/mephisto/abstractions/providers/prolific/api/data_models/workspace_balance.py +++ b/mephisto/abstractions/providers/prolific/api/data_models/workspace_balance.py @@ -16,6 +16,7 @@ class WorkspaceBalance(BaseModel): More about Workspaces: https://docs.prolific.co/docs/api-docs/public/#tag/Workspaces """ + currency_code: str total_balance: Union[Decimal, float, int] balance_breakdown: Dict[str, Union[Decimal, float, int]] @@ -23,29 +24,29 @@ class WorkspaceBalance(BaseModel): available_balance_breakdown: Dict[str, Union[Decimal, float, int]] schema = { - 'type': 'object', - 'properties': { - 'currency_code': {'type': 'string'}, - 'total_balance': {'type': 'number'}, - 'balance_breakdown': { - 'type': 'object', - 'properties': { - 'rewards': {'type': 'number'}, - 'fees': {'type': 'number'}, - 'vat': {'type': 'number'}, + "type": "object", + "properties": { + "currency_code": {"type": "string"}, + "total_balance": {"type": "number"}, + "balance_breakdown": { + "type": "object", + "properties": { + "rewards": {"type": "number"}, + "fees": {"type": "number"}, + "vat": {"type": "number"}, }, }, - 'available_balance': {'type': 'number'}, - 'available_balance_breakdown': { - 'type': 'object', - 'properties': { - 'rewards': {'type': 'number'}, - 'fees': {'type': 'number'}, - 'vat': {'type': 'number'}, + "available_balance": {"type": "number"}, + "available_balance_breakdown": { + "type": "object", + "properties": { + "rewards": {"type": "number"}, + "fees": {"type": "number"}, + "vat": {"type": "number"}, }, }, }, } def __str__(self) -> str: - return f'{self.__class__.__name__} {self.total_balance} {self.currency_code}' + return f"{self.__class__.__name__} {self.total_balance} {self.currency_code}" diff --git a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/age_range_eligibility_requirement.py b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/age_range_eligibility_requirement.py index 6fa9fdeab..71535a59f 100644 --- a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/age_range_eligibility_requirement.py +++ b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/age_range_eligibility_requirement.py @@ -1,5 +1,5 @@ from mephisto.abstractions.providers.prolific.api.constants import ( - ELIGIBILITY_REQUIREMENT_AGE_RANGE_QUESTION_ID + ELIGIBILITY_REQUIREMENT_AGE_RANGE_QUESTION_ID, ) from .base_eligibility_requirement import BaseEligibilityRequirement @@ -8,8 +8,9 @@ class AgeRangeEligibilityRequirement(BaseEligibilityRequirement): """ Details https://docs.prolific.co/docs/api-docs/public/#tag/Requirements/Requirements-object """ - name = 'AgeRangeEligibilityRequirement' - prolific_cls_name = f'web.eligibility.models.{name}' + + name = "AgeRangeEligibilityRequirement" + prolific_cls_name = f"web.eligibility.models.{name}" def __init__(self, min_age: int, max_age: int): self.min_age = min_age @@ -20,6 +21,6 @@ def to_prolific_dict(self) -> dict: # HACK: Hardcoded Question IDs (Prolific doesn't have a better way for now) # TODO (#1008): Make this dynamic as soon as possible - prolific_dict['query'] = dict(id=ELIGIBILITY_REQUIREMENT_AGE_RANGE_QUESTION_ID) + prolific_dict["query"] = dict(id=ELIGIBILITY_REQUIREMENT_AGE_RANGE_QUESTION_ID) return prolific_dict diff --git a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/approval_numbers_eligibility_requirement.py b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/approval_numbers_eligibility_requirement.py index 128bccccc..6dc5ea2be 100644 --- a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/approval_numbers_eligibility_requirement.py +++ b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/approval_numbers_eligibility_requirement.py @@ -7,8 +7,9 @@ class ApprovalNumbersEligibilityRequirement(BaseEligibilityRequirement): """ Details https://docs.prolific.co/docs/api-docs/public/#tag/Requirements/Requirements-object """ - name = 'ApprovalNumbersEligibilityRequirement' - prolific_cls_name = f'web.eligibility.models.{name}' + + name = "ApprovalNumbersEligibilityRequirement" + prolific_cls_name = f"web.eligibility.models.{name}" def __init__( self, @@ -17,4 +18,3 @@ def __init__( ): self.minimum_approvals = minimum_approvals self.maximum_approvals = maximum_approvals - diff --git a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/approval_rate_eligibility_requirement.py b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/approval_rate_eligibility_requirement.py index ba61f652b..52910756b 100644 --- a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/approval_rate_eligibility_requirement.py +++ b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/approval_rate_eligibility_requirement.py @@ -5,8 +5,9 @@ class ApprovalRateEligibilityRequirement(BaseEligibilityRequirement): """ Details https://docs.prolific.co/docs/api-docs/public/#tag/Requirements/Requirements-object """ - name = 'ApprovalRateEligibilityRequirement' - prolific_cls_name = f'web.eligibility.models.{name}' + + name = "ApprovalRateEligibilityRequirement" + prolific_cls_name = f"web.eligibility.models.{name}" def __init__(self, minimum_approval_rate: int, maximum_approval_rate: int): self.minimum_approval_rate = minimum_approval_rate diff --git a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/base_eligibility_requirement.py b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/base_eligibility_requirement.py index 238faf5a3..bd5ae83d6 100644 --- a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/base_eligibility_requirement.py +++ b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/base_eligibility_requirement.py @@ -25,12 +25,13 @@ def __init__(self, min_value: int, max_value: int): 3. In the code all these requirements will be converted to the Prolific format (see mephisto.abstractions.providers.prolific.prolific_utils._get_eligibility_requirements) """ + prolific_cls_name = None @classmethod def params(cls): params = list(inspect.signature(cls.__init__).parameters.keys()) - params.remove('self') + params.remove("self") return params def to_prolific_dict(self) -> dict: @@ -45,16 +46,18 @@ def to_prolific_dict(self) -> dict: param_value = list(param_value) if param_value: - prolific_dict['attributes'].append(dict( - name=param_name, - value=param_value, - )) + prolific_dict["attributes"].append( + dict( + name=param_name, + value=param_value, + ) + ) return prolific_dict def __str__(self) -> str: _str = self.__class__.__name__ for param_name in self.params(): - _str += f' {param_name}={getattr(self, param_name, None)}' + _str += f" {param_name}={getattr(self, param_name, None)}" return _str def __repr__(self) -> str: diff --git a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/custom_black_list_eligibility_requirement.py b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/custom_black_list_eligibility_requirement.py index dea869011..b3fb67c94 100644 --- a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/custom_black_list_eligibility_requirement.py +++ b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/custom_black_list_eligibility_requirement.py @@ -7,8 +7,9 @@ class CustomBlacklistEligibilityRequirement(BaseEligibilityRequirement): """ Details https://docs.prolific.co/docs/api-docs/public/#tag/Requirements/Requirements-object """ - name = 'CustomBlacklistEligibilityRequirement' - prolific_cls_name = f'web.eligibility.models.{name}' + + name = "CustomBlacklistEligibilityRequirement" + prolific_cls_name = f"web.eligibility.models.{name}" def __init__(self, black_list: List[str]): self.black_list = black_list diff --git a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/custom_white_list_eligibility_requirement.py b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/custom_white_list_eligibility_requirement.py index 71556c807..5c87e5bac 100644 --- a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/custom_white_list_eligibility_requirement.py +++ b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/custom_white_list_eligibility_requirement.py @@ -7,8 +7,9 @@ class CustomWhitelistEligibilityRequirement(BaseEligibilityRequirement): """ Details https://docs.prolific.co/docs/api-docs/public/#tag/Requirements/Requirements-object """ - name = 'CustomWhitelistEligibilityRequirement' - prolific_cls_name = f'web.eligibility.models.{name}' + + name = "CustomWhitelistEligibilityRequirement" + prolific_cls_name = f"web.eligibility.models.{name}" def __init__(self, white_list: List[str]): self.white_list = white_list diff --git a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/joined_before_eligibility_requirement.py b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/joined_before_eligibility_requirement.py index 14d9b9769..178e62170 100644 --- a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/joined_before_eligibility_requirement.py +++ b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/joined_before_eligibility_requirement.py @@ -5,8 +5,9 @@ class JoinedBeforeEligibilityRequirement(BaseEligibilityRequirement): """ Details https://docs.prolific.co/docs/api-docs/public/#tag/Requirements/Requirements-object """ - name = 'JoinedBeforeEligibilityRequirement' - prolific_cls_name = f'web.eligibility.models.{name}' + + name = "JoinedBeforeEligibilityRequirement" + prolific_cls_name = f"web.eligibility.models.{name}" def __init__(self, joined_before: str): self.joined_before = joined_before diff --git a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/participant_group_eligibility_requirement.py b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/participant_group_eligibility_requirement.py index 08135585e..d25e4346b 100644 --- a/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/participant_group_eligibility_requirement.py +++ b/mephisto/abstractions/providers/prolific/api/eligibility_requirement_classes/participant_group_eligibility_requirement.py @@ -5,8 +5,9 @@ class ParticipantGroupEligibilityRequirement(BaseEligibilityRequirement): """ Details https://docs.prolific.co/docs/api-docs/public/#tag/Requirements/Requirements-object """ - name = 'ParticipantGroupEligibilityRequirement' - prolific_cls_name = f'web.eligibility.models.{name}' + + name = "ParticipantGroupEligibilityRequirement" + prolific_cls_name = f"web.eligibility.models.{name}" def __init__(self, id: str): self.id = id @@ -15,9 +16,11 @@ def __init__(self, id: str): def to_prolific_dict(self) -> dict: prolific_dict = dict( _cls=self.prolific_cls_name, - attributes=[dict( - id=self.id, - value=self.value, - )], + attributes=[ + dict( + id=self.id, + value=self.value, + ) + ], ) return prolific_dict diff --git a/mephisto/abstractions/providers/prolific/api/eligibility_requirements.py b/mephisto/abstractions/providers/prolific/api/eligibility_requirements.py index 4f25eed3e..fa191bbc5 100644 --- a/mephisto/abstractions/providers/prolific/api/eligibility_requirements.py +++ b/mephisto/abstractions/providers/prolific/api/eligibility_requirements.py @@ -11,16 +11,16 @@ class EligibilityRequirements(BaseAPIResource): - list_api_endpoint = 'eligibility-requirements/' - count_api_endpoint = 'eligibility-count/' + list_api_endpoint = "eligibility-requirements/" + count_api_endpoint = "eligibility-count/" @classmethod def list(cls) -> List[EligibilityRequirement]: response_json = cls.get(cls.list_api_endpoint) - eligibility_requirements = [EligibilityRequirement(**s) for s in response_json['results']] + eligibility_requirements = [EligibilityRequirement(**s) for s in response_json["results"]] return eligibility_requirements @classmethod def count_participants(cls) -> int: response_json = cls.post(cls.count_api_endpoint) - return response_json['count'] + return response_json["count"] diff --git a/mephisto/abstractions/providers/prolific/api/exceptions.py b/mephisto/abstractions/providers/prolific/api/exceptions.py index b702a66c6..9033c03a2 100644 --- a/mephisto/abstractions/providers/prolific/api/exceptions.py +++ b/mephisto/abstractions/providers/prolific/api/exceptions.py @@ -10,19 +10,20 @@ class ProlificException(Exception): - """ Main Prolific exception. All other exceptions should be inherited from it """ - default_message: str = 'Prolific error' + """Main Prolific exception. All other exceptions should be inherited from it""" + + default_message: str = "Prolific error" def __init__(self, message: Optional[str] = None): self.message = message or self.default_message class ProlificAPIKeyError(ProlificException): - default_message = 'API key is missing.' + default_message = "API key is missing." class ProlificRequestError(ProlificException): - default_message = 'Request error.' + default_message = "Request error." status_code = status.HTTP_400_BAD_REQUEST def __init__(self, message: Optional[str] = None, status_code: Optional[int] = None): @@ -31,5 +32,5 @@ def __init__(self, message: Optional[str] = None, status_code: Optional[int] = N class ProlificAuthenticationError(ProlificRequestError): - default_message = 'Authentication was failed.' + default_message = "Authentication was failed." status_code = status.HTTP_401_UNAUTHORIZED diff --git a/mephisto/abstractions/providers/prolific/api/messages.py b/mephisto/abstractions/providers/prolific/api/messages.py index 766177251..a08ca442b 100644 --- a/mephisto/abstractions/providers/prolific/api/messages.py +++ b/mephisto/abstractions/providers/prolific/api/messages.py @@ -13,12 +13,14 @@ class Messages(BaseAPIResource): - list_api_endpoint = 'messages/' - retrieve_api_endpoint = 'messages/unread/' + list_api_endpoint = "messages/" + retrieve_api_endpoint = "messages/unread/" @classmethod def list( - cls, user_id: Optional[str] = None, created_after: Optional[datetime] = None, + cls, + user_id: Optional[str] = None, + created_after: Optional[datetime] = None, ) -> List[Message]: """ Get messages between you and another user or your messages with all users @@ -30,12 +32,12 @@ def list( """ endpoint = cls.list_api_endpoint if user_id: - endpoint += f'?user_id={user_id}' + endpoint += f"?user_id={user_id}" elif created_after: - endpoint += f'?created_after={created_after.isoformat()}' + endpoint += f"?created_after={created_after.isoformat()}" response_json = cls.get(endpoint) - messages = [Message(**s) for s in response_json['results']] + messages = [Message(**s) for s in response_json["results"]] return messages @classmethod @@ -46,7 +48,7 @@ def list_unread(cls) -> List[Message]: It does not mark those messages as read """ response_json = cls.get(cls.list_api_endpoint) - messages = [Message(**s) for s in response_json['results']] + messages = [Message(**s) for s in response_json["results"]] return messages @classmethod diff --git a/mephisto/abstractions/providers/prolific/api/participant_groups.py b/mephisto/abstractions/providers/prolific/api/participant_groups.py index e10b035ca..be2664977 100644 --- a/mephisto/abstractions/providers/prolific/api/participant_groups.py +++ b/mephisto/abstractions/providers/prolific/api/participant_groups.py @@ -14,14 +14,16 @@ class ParticipantGroups(BaseAPIResource): - list_api_endpoint = 'participant-groups/' - retrieve_api_endpoint = 'participant-groups/{id}/' - remove_api_endpoint = 'participant-groups/{id}/' - list_participants_for_group_api_endpoint = 'participant-groups/{id}/participants/' + list_api_endpoint = "participant-groups/" + retrieve_api_endpoint = "participant-groups/{id}/" + remove_api_endpoint = "participant-groups/{id}/" + list_participants_for_group_api_endpoint = "participant-groups/{id}/participants/" @classmethod def list( - cls, project_id: Optional[str] = None, is_active: bool = True, + cls, + project_id: Optional[str] = None, + is_active: bool = True, ) -> List[ParticipantGroup]: """ API docs for this endpoint: @@ -30,16 +32,16 @@ def list( """ params = {} if project_id: - params['project_id'] = project_id + params["project_id"] = project_id if is_active: - params['is_active'] = is_active + params["is_active"] = is_active endpoint = cls.list_api_endpoint if params: - endpoint += '?' + urllib.parse.urlencode(params) + endpoint += "?" + urllib.parse.urlencode(params) response_json = cls.get(endpoint) - participant_groups = [ParticipantGroup(**s) for s in response_json['results']] + participant_groups = [ParticipantGroup(**s) for s in response_json["results"]] return participant_groups @classmethod @@ -82,7 +84,7 @@ def list_participants_for_group(cls, id: str) -> List[Participant]: Participant-Groups/paths/~1api~1v1~1participant-groups~1%7Bid%7D~1participants~1/get """ response_json = cls.get(cls.list_participants_for_group_api_endpoint.format(id=id)) - participants = [Participant(**s) for s in response_json['results']] + participants = [Participant(**s) for s in response_json["results"]] return participants @classmethod @@ -95,12 +97,14 @@ def add_participants_to_group(cls, id: str, participant_ids: List[str]) -> List[ endpoint = cls.list_participants_for_group_api_endpoint.format(id=id) params = dict(participant_ids=participant_ids) response_json = cls.post(endpoint, params=params) - participants = [Participant(**s) for s in response_json['results']] + participants = [Participant(**s) for s in response_json["results"]] return participants @classmethod def remove_participants_from_group( - cls, id: str, participant_ids: List[str], + cls, + id: str, + participant_ids: List[str], ) -> List[Participant]: """ API docs for this endpoint: @@ -110,5 +114,5 @@ def remove_participants_from_group( endpoint = cls.list_participants_for_group_api_endpoint.format(id=id) params = dict(participant_ids=participant_ids) response_json = cls.delete(endpoint, params=params) - participants = [Participant(**s) for s in response_json['results']] + participants = [Participant(**s) for s in response_json["results"]] return participants diff --git a/mephisto/abstractions/providers/prolific/api/projects.py b/mephisto/abstractions/providers/prolific/api/projects.py index 05500fb34..953451180 100644 --- a/mephisto/abstractions/providers/prolific/api/projects.py +++ b/mephisto/abstractions/providers/prolific/api/projects.py @@ -11,20 +11,21 @@ class Projects(BaseAPIResource): - list_for_workspace_api_endpoint = 'workspaces/{workspace_id}/projects/' - retrieve_for_workspace_api_endpoint = 'workspaces/{workspace_id}/projects/{project_id}/' + list_for_workspace_api_endpoint = "workspaces/{workspace_id}/projects/" + retrieve_for_workspace_api_endpoint = "workspaces/{workspace_id}/projects/{project_id}/" @classmethod def list_for_workspace(cls, workspace_id: str) -> List[Project]: endpoint = cls.list_for_workspace_api_endpoint.format(workspace_id=workspace_id) response_json = cls.get(endpoint) - projects = [Project(**s) for s in response_json['results']] + projects = [Project(**s) for s in response_json["results"]] return projects @classmethod def retrieve_for_workspace(cls, workspace_id: str, project_id: str) -> Project: endpoint = cls.retrieve_for_workspace_api_endpoint.format( - workspace_id=workspace_id, project_id=project_id, + workspace_id=workspace_id, + project_id=project_id, ) response_json = cls.get(endpoint) return Project(**response_json) diff --git a/mephisto/abstractions/providers/prolific/api/studies.py b/mephisto/abstractions/providers/prolific/api/studies.py index 3fbef93d9..2a1b25510 100644 --- a/mephisto/abstractions/providers/prolific/api/studies.py +++ b/mephisto/abstractions/providers/prolific/api/studies.py @@ -13,14 +13,14 @@ class Studies(BaseAPIResource): - list_api_endpoint = 'studies/' - list_for_project_api_endpoint = 'projects/{project_id}/studies/' - retrieve_api_endpoint = 'studies/{id}/' - update_api_endpoint = 'studies/{id}/' - remove_api_endpoint = 'studies/{id}/' - publish_cost_api_endpoint = 'studies/{id}/transition/' - stop_cost_api_endpoint = 'studies/{id}/transition/' - calculate_cost_api_endpoint = 'study-cost-calculator/' + list_api_endpoint = "studies/" + list_for_project_api_endpoint = "projects/{project_id}/studies/" + retrieve_api_endpoint = "studies/{id}/" + update_api_endpoint = "studies/{id}/" + remove_api_endpoint = "studies/{id}/" + publish_cost_api_endpoint = "studies/{id}/transition/" + stop_cost_api_endpoint = "studies/{id}/transition/" + calculate_cost_api_endpoint = "study-cost-calculator/" @classmethod def list(cls) -> List[Study]: @@ -29,7 +29,7 @@ def list(cls) -> List[Study]: https://docs.prolific.co/docs/api-docs/public/#tag/Studies/paths/~1api~1v1~1studies~1/get """ response_json = cls.get(cls.list_api_endpoint) - studies = [Study(**s) for s in response_json['results']] + studies = [Study(**s) for s in response_json["results"]] return studies @classmethod @@ -41,7 +41,7 @@ def list_for_project(cls, project_id: str) -> List[Study]: """ endpoint = cls.list_for_project_api_endpoint.format(project_id=project_id) response_json = cls.get(endpoint) - studies = [Study(**s) for s in response_json['results']] + studies = [Study(**s) for s in response_json["results"]] return studies @classmethod @@ -116,7 +116,9 @@ def stop(cls, id: str) -> Study: @classmethod def calculate_cost( - cls, reward: Union[int, float], total_available_places: int, + cls, + reward: Union[int, float], + total_available_places: int, ) -> Union[int, float]: """ API docs for this endpoint: @@ -128,4 +130,4 @@ def calculate_cost( total_available_places=total_available_places, ) response_json = cls.post(cls.calculate_cost_api_endpoint, params=params) - return response_json['total_cost'] + return response_json["total_cost"] diff --git a/mephisto/abstractions/providers/prolific/api/submissions.py b/mephisto/abstractions/providers/prolific/api/submissions.py index 8aeb6cf73..b0b90770a 100644 --- a/mephisto/abstractions/providers/prolific/api/submissions.py +++ b/mephisto/abstractions/providers/prolific/api/submissions.py @@ -14,9 +14,9 @@ class Submissions(BaseAPIResource): - list_api_endpoint = 'submissions/' - retrieve_api_endpoint = 'submissions/{id}/' - change_status_api_endpoint = 'submissions/{id}/transition/' + list_api_endpoint = "submissions/" + retrieve_api_endpoint = "submissions/{id}/" + change_status_api_endpoint = "submissions/{id}/transition/" @classmethod def list(cls, study_id: Optional[str] = None) -> List[ListSubmission]: @@ -27,9 +27,9 @@ def list(cls, study_id: Optional[str] = None) -> List[ListSubmission]: """ endpoint = cls.list_api_endpoint if study_id: - endpoint = f'{endpoint}?study={study_id}' + endpoint = f"{endpoint}?study={study_id}" response_json = cls.get(endpoint) - submissions = [ListSubmission(**s) for s in response_json['results']] + submissions = [ListSubmission(**s) for s in response_json["results"]] return submissions @classmethod @@ -60,8 +60,8 @@ def _change_status( action=action, ) if reason_message: - params['message'] = reason_message - params['rejection_category'] = rejection_category, + params["message"] = reason_message + params["rejection_category"] = (rejection_category,) endpoint = cls.change_status_api_endpoint.format(id=id) response_json = cls.post(endpoint, params=params) diff --git a/mephisto/abstractions/providers/prolific/api/users.py b/mephisto/abstractions/providers/prolific/api/users.py index 18bc0c773..17257daf9 100644 --- a/mephisto/abstractions/providers/prolific/api/users.py +++ b/mephisto/abstractions/providers/prolific/api/users.py @@ -9,7 +9,7 @@ class Users(BaseAPIResource): - me_api_endpoint = 'users/me/' + me_api_endpoint = "users/me/" @classmethod def me(cls) -> User: diff --git a/mephisto/abstractions/providers/prolific/prolific_agent.py b/mephisto/abstractions/providers/prolific/prolific_agent.py index 35a9067df..4e08d1308 100644 --- a/mephisto/abstractions/providers/prolific/prolific_agent.py +++ b/mephisto/abstractions/providers/prolific/prolific_agent.py @@ -65,9 +65,7 @@ def __init__( def _get_client(self) -> ProlificClient: """Get a Prolific client""" - requester: "ProlificRequester" = cast( - "ProlificRequester", self.unit.get_requester() - ) + requester: "ProlificRequester" = cast("ProlificRequester", self.unit.get_requester()) return self.datastore.get_client_for_requester(requester.requester_name) @property @@ -92,9 +90,7 @@ def new_from_provider_data( f"Registering Prolific Submission in datastore from Prolific. Data: {provider_data}" ) - assert isinstance( - unit, ProlificUnit - ), "Can only register Prolific agents to Prolific units" + assert isinstance(unit, ProlificUnit), "Can only register Prolific agents to Prolific units" prolific_study_id = provider_data["prolific_study_id"] prolific_submission_id = provider_data["assignment_id"] @@ -109,16 +105,16 @@ def approve_work(self) -> None: logger.debug(f"{self.log_prefix}Approving work") if self.get_status() == AgentState.STATUS_APPROVED: - logger.info( - f"{self.log_prefix}Approving already approved agent {self}, skipping" - ) + logger.info(f"{self.log_prefix}Approving already approved agent {self}, skipping") return client = self._get_client() prolific_study_id = self.unit.get_prolific_study_id() worker_id = self.worker.get_prolific_participant_id() prolific_utils.approve_work( - client, study_id=prolific_study_id, worker_id=worker_id, + client, + study_id=prolific_study_id, + worker_id=worker_id, ) logger.debug( @@ -137,7 +133,9 @@ def soft_reject_work(self) -> None: prolific_study_id = self.unit.get_prolific_study_id() worker_id = self.worker.get_prolific_participant_id() prolific_utils.approve_work( - client, study_id=prolific_study_id, worker_id=worker_id, + client, + study_id=prolific_study_id, + worker_id=worker_id, ) logger.debug( @@ -163,7 +161,9 @@ def reject_work(self, reason) -> None: try: prolific_utils.reject_work( - client, study_id=prolific_study_id, worker_id=worker_id, + client, + study_id=prolific_study_id, + worker_id=worker_id, ) except ProlificException: logger.info( @@ -209,7 +209,7 @@ def get_status(self) -> str: # Get Submission from Prolific, records status datastore_unit = self.datastore.get_unit(unit_agent_pairing.db_id) - prolific_submission_id = datastore_unit['prolific_submission_id'] + prolific_submission_id = datastore_unit["prolific_submission_id"] prolific_submission = None if prolific_submission_id: prolific_submission = prolific_utils.get_submission(client, prolific_submission_id) @@ -229,7 +229,7 @@ def get_status(self) -> str: prolific_submission.status, ) if not provider_status: - raise Exception(f'Unexpected Submission status {prolific_submission.status}') + raise Exception(f"Unexpected Submission status {prolific_submission.status}") self.update_status(provider_status) diff --git a/mephisto/abstractions/providers/prolific/prolific_datastore.py b/mephisto/abstractions/providers/prolific/prolific_datastore.py index 912473793..58e53bbb9 100644 --- a/mephisto/abstractions/providers/prolific/prolific_datastore.py +++ b/mephisto/abstractions/providers/prolific/prolific_datastore.py @@ -322,7 +322,7 @@ def get_blocked_workers(self) -> List[dict]: return results def get_bloked_participant_ids(self) -> List[str]: - return [w['worker_id'] for w in self.get_blocked_workers()] + return [w["worker_id"] for w in self.get_blocked_workers()] def ensure_unit_exists(self, unit_id: str) -> None: """Create a record of this unit if it doesn't exist""" @@ -443,9 +443,7 @@ def get_client_for_requester(self, requester_name: str) -> ProlificClient: """ return self.get_session_for_requester(requester_name) - def get_qualification_mapping( - self, qualification_name: str - ) -> Optional[sqlite3.Row]: + def get_qualification_mapping(self, qualification_name: str) -> Optional[sqlite3.Row]: """Get the mapping between Mephisto qualifications and Prolific Participant Group""" with self.table_access_condition: conn = self._get_connection() @@ -514,9 +512,7 @@ def create_participant_group_mapping( ), "Cannot be none given is_unique_failure on insert" db_requester_id = db_qualification["requester_id"] - db_prolific_qualification_name = db_qualification[ - "prolific_participant_group_name" - ] + db_prolific_qualification_name = db_qualification["prolific_participant_group_name"] if db_requester_id != requester_id: logger.warning( @@ -536,7 +532,8 @@ def create_participant_group_mapping( raise e def delete_participant_groups_by_participant_group_ids( - self, participant_group_ids: List[str] = None, + self, + participant_group_ids: List[str] = None, ) -> None: """Delete participant_groups by Participant Group IDs""" if not participant_group_ids: @@ -545,11 +542,11 @@ def delete_participant_groups_by_participant_group_ids( with self.table_access_condition, self._get_connection() as conn: c = conn.cursor() - participant_group_ids_block = '' + participant_group_ids_block = "" if participant_group_ids: task_run_ids_str = ",".join([f'"{pgi}"' for pgi in participant_group_ids]) participant_group_ids_block = ( - f'AND prolific_participant_group_id IN ({task_run_ids_str})' + f"AND prolific_participant_group_id IN ({task_run_ids_str})" ) c.execute( @@ -589,14 +586,12 @@ def create_qualification_mapping( ), ) - def find_studies_by_status( - self, statuses: List[str], exclude: bool = False - ) -> List[dict]: + def find_studies_by_status(self, statuses: List[str], exclude: bool = False) -> List[dict]: """Find all studies having or excluding certain statuses""" if not statuses: return [] - logic_str = 'NOT' if exclude else '' + logic_str = "NOT" if exclude else "" statuses_str = ",".join([f'"{s}"' for s in statuses]) with self.table_access_condition, self._get_connection() as conn: @@ -611,18 +606,21 @@ def find_studies_by_status( return results def find_qualifications_for_running_studies( - self, qualification_ids: List[str], + self, + qualification_ids: List[str], ) -> List[dict]: """Find qualifications by Mephisto ids of qualifications for all incomplete studies""" if not qualification_ids: return [] running_studies = self.find_studies_by_status( - statuses=[StudyStatus.COMPLETED, StudyStatus.AWAITING_REVIEW], exclude=True, + statuses=[StudyStatus.COMPLETED, StudyStatus.AWAITING_REVIEW], + exclude=True, ) - task_run_ids = [s['task_run_id'] for s in running_studies] + task_run_ids = [s["task_run_id"] for s in running_studies] return self.find_qualifications_by_ids( - qualification_ids=qualification_ids, task_run_ids=task_run_ids, + qualification_ids=qualification_ids, + task_run_ids=task_run_ids, ) def find_qualifications_by_ids( @@ -637,18 +635,17 @@ def find_qualifications_by_ids( with self.table_access_condition, self._get_connection() as conn: c = conn.cursor() - qualification_ids_block = '' + qualification_ids_block = "" if qualification_ids: - qualification_ids_block = ' OR '.join( - 'qualification_ids LIKE \'%"' + str(_id) + '"%\'' - for _id in qualification_ids + qualification_ids_block = " OR ".join( + "qualification_ids LIKE '%\"" + str(_id) + "\"%'" for _id in qualification_ids ) - qualification_ids_block = f'({qualification_ids_block})' + qualification_ids_block = f"({qualification_ids_block})" - task_run_ids_block = '' + task_run_ids_block = "" if task_run_ids: task_run_ids_str = ",".join([f'"{tid}"' for tid in task_run_ids]) - task_run_ids_block = f'AND task_run_id IN ({task_run_ids_str})' + task_run_ids_block = f"AND task_run_id IN ({task_run_ids_str})" c.execute( f""" @@ -660,7 +657,8 @@ def find_qualifications_by_ids( return results def delete_qualifications_by_participant_group_ids( - self, participant_group_ids: List[str] = None, + self, + participant_group_ids: List[str] = None, ) -> None: """Delete qualifications by Participant Group IDs""" if not participant_group_ids: @@ -669,11 +667,11 @@ def delete_qualifications_by_participant_group_ids( with self.table_access_condition, self._get_connection() as conn: c = conn.cursor() - participant_group_ids_block = '' + participant_group_ids_block = "" if participant_group_ids: task_run_ids_str = ",".join([f'"{pgi}"' for pgi in participant_group_ids]) participant_group_ids_block = ( - f'AND prolific_participant_group_id IN ({task_run_ids_str})' + f"AND prolific_participant_group_id IN ({task_run_ids_str})" ) c.execute( diff --git a/mephisto/abstractions/providers/prolific/prolific_provider.py b/mephisto/abstractions/providers/prolific/prolific_provider.py index 7081f1d5b..ca0ee7564 100644 --- a/mephisto/abstractions/providers/prolific/prolific_provider.py +++ b/mephisto/abstractions/providers/prolific/prolific_provider.py @@ -171,10 +171,12 @@ def _get_client(self, requester_name: str) -> ProlificClient: return self.datastore.get_client_for_requester(requester_name) def _get_qualified_workers( - self, qualifications: List[QualificationType], bloked_participant_ids: List[str], + self, + qualifications: List[QualificationType], + bloked_participant_ids: List[str], ) -> List["Worker"]: qualified_workers = [] - workers: List[Worker] = self.db.find_workers(provider_type='prolific') + workers: List[Worker] = self.db.find_workers(provider_type="prolific") # `worker_name` is Prolific Participant ID in provider-specific datastore available_workers = [w for w in workers if w.worker_name not in bloked_participant_ids] @@ -191,7 +193,7 @@ def _create_participant_group_with_qualified_workers( workers_ids: List[str], prolific_project_id: str, ) -> ParticipantGroup: - participant_proup_name = f'PG {datetime.now(timezone.utc).isoformat()}' + participant_proup_name = f"PG {datetime.now(timezone.utc).isoformat()}" prolific_participant_group = prolific_utils.create_qualification( client, prolific_project_id, @@ -226,9 +228,7 @@ def setup_resources_for_task_run( config_dir = os.path.join(self.datastore.datastore_root, task_run_id) frame_height = ( - task_run.get_blueprint() - .get_frontend_args() - .get("frame_height", DEFAULT_FRAME_HEIGHT) + task_run.get_blueprint().get_frontend_args().get("frame_height", DEFAULT_FRAME_HEIGHT) ) # Mephisto qualifications @@ -236,7 +236,9 @@ def setup_resources_for_task_run( # Get provider-specific qualification from SharedState prolific_specific_qualifications = getattr( - shared_state, 'prolific_specific_qualifications', [], + shared_state, + "prolific_specific_qualifications", + [], ) # Update with ones from YAML config under `provider` title yaml_prolific_specific_qualifications = args.provider.prolific_eligibility_requirements @@ -249,11 +251,9 @@ def setup_resources_for_task_run( ) # Get Prolific specific data to create a task - prolific_workspace: Workspace = ( - prolific_utils.find_or_create_prolific_workspace( - client, - title=args.provider.prolific_workspace_name, - ) + prolific_workspace: Workspace = prolific_utils.find_or_create_prolific_workspace( + client, + title=args.provider.prolific_workspace_name, ) prolific_project: Project = prolific_utils.find_or_create_prolific_project( client, @@ -268,37 +268,39 @@ def setup_resources_for_task_run( if blocked_participant_ids: new_prolific_specific_qualifications = [] # Add empty Blacklist in case if there is not in state or config - blacklist_qualification = DictConfig(dict( - name=CustomBlacklistEligibilityRequirement.name, - black_list=[], - )) + blacklist_qualification = DictConfig( + dict( + name=CustomBlacklistEligibilityRequirement.name, + black_list=[], + ) + ) for prolific_specific_qualification in prolific_specific_qualifications: - name = prolific_specific_qualification['name'] + name = prolific_specific_qualification["name"] if name == CustomBlacklistEligibilityRequirement.name: blacklist_qualification = prolific_specific_qualification elif name == CustomWhitelistEligibilityRequirement.name: # Remove blocked Participat IDs from Whitelist Eligibility Requirement whitelist_qualification = prolific_specific_qualification - prev_value = whitelist_qualification['white_list'] - whitelist_qualification['white_list'] = [ + prev_value = whitelist_qualification["white_list"] + whitelist_qualification["white_list"] = [ p for p in prev_value if p not in blocked_participant_ids ] new_prolific_specific_qualifications.append(whitelist_qualification) elif name == ParticipantGroupEligibilityRequirement.name: # Remove blocked Participat IDs from Participant Group Eligibility Requirement client.ParticipantGroups.remove_participants_from_group( - id=prolific_specific_qualification['id'], + id=prolific_specific_qualification["id"], participant_ids=blocked_participant_ids, ) else: new_prolific_specific_qualifications.append(prolific_specific_qualification) # Set Blacklist Eligibility Requirement - blacklist_qualification['black_list'] = list(set( - blacklist_qualification['black_list'] + blocked_participant_ids - )) + blacklist_qualification["black_list"] = list( + set(blacklist_qualification["black_list"] + blocked_participant_ids) + ) new_prolific_specific_qualifications.append(blacklist_qualification) prolific_specific_qualifications = new_prolific_specific_qualifications @@ -308,24 +310,25 @@ def setup_resources_for_task_run( if qualified_workers: prolific_workers_ids = [w.worker_name for w in qualified_workers] # Create a new Participant Group - prolific_participant_group = ( - self._create_participant_group_with_qualified_workers( - client, - requester, - prolific_workers_ids, - prolific_project.id, - ) + prolific_participant_group = self._create_participant_group_with_qualified_workers( + client, + requester, + prolific_workers_ids, + prolific_project.id, ) # Add this Participant Group to Prolific-specific requirements - prolific_specific_qualifications.append({ - 'name': ParticipantGroupEligibilityRequirement.name, - 'id': prolific_participant_group.id, - }) + prolific_specific_qualifications.append( + { + "name": ParticipantGroupEligibilityRequirement.name, + "id": prolific_participant_group.id, + } + ) - qualification_names = [q['qualification_name'] for q in qualifications] + qualification_names = [q["qualification_name"] for q in qualifications] qualification_objs = self.db.find_qualifications() qualifications_ids = [ - q.db_id for q in qualification_objs + q.db_id + for q in qualification_objs if q.qualification_name in qualification_names ] self.datastore.create_qualification_mapping( @@ -370,9 +373,7 @@ def setup_resources_for_task_run( self.datastore.new_study( prolific_study_id=prolific_study.id, study_link=prolific_study.external_study_url, - duration_in_seconds=( - args.provider.prolific_estimated_completion_time_in_minutes * 60 - ), + duration_in_seconds=(args.provider.prolific_estimated_completion_time_in_minutes * 60), task_run_id=task_run_id, status=StudyStatus.ACTIVE, ) @@ -395,7 +396,7 @@ def cleanup_resources_from_task_run(self, task_run: "TaskRun", server_url: str) # Remove from Provider-specific datastore participant_group_ids = [ - i['prolific_participant_group_id'] for i in datastore_qualifications + i["prolific_participant_group_id"] for i in datastore_qualifications ] self.datastore.delete_qualifications_by_participant_group_ids( participant_group_ids=participant_group_ids, @@ -407,7 +408,8 @@ def cleanup_resources_from_task_run(self, task_run: "TaskRun", server_url: str) # Remove from Prolific for qualification in datastore_qualifications: prolific_utils.delete_qualification( - client, qualification['prolific_participant_group_id'], + client, + qualification["prolific_participant_group_id"], ) @classmethod @@ -430,7 +432,8 @@ def cleanup_qualification(self, qualification_name: str) -> None: client = requester._get_client(requester.requester_name) try: prolific_utils.delete_qualification( - client, mapping["prolific_participant_group_id"], + client, + mapping["prolific_participant_group_id"], ) except ProlificException: logger.exception("Could not delete qualification on Prolific") diff --git a/mephisto/abstractions/providers/prolific/prolific_requester.py b/mephisto/abstractions/providers/prolific/prolific_requester.py index a8df3d172..3cc1b18a5 100644 --- a/mephisto/abstractions/providers/prolific/prolific_requester.py +++ b/mephisto/abstractions/providers/prolific/prolific_requester.py @@ -89,7 +89,9 @@ def get_available_budget(self) -> float: return balance def create_new_qualification( - self, prolific_project_id: str, qualification_name: str, + self, + prolific_project_id: str, + qualification_name: str, ) -> ParticipantGroup: """ Create a new qualification (Prolific Participant Group) on Prolific @@ -98,14 +100,18 @@ def create_new_qualification( client = self._get_client(self.requester_name) _qualification_name = qualification_name qualification = prolific_utils.find_or_create_qualification( - client, prolific_project_id, qualification_name, + client, + prolific_project_id, + qualification_name, ) if qualification is None: # Try to append time to make the qualification unique _qualification_name = f"{qualification_name}_{time.time()}" qualification = prolific_utils.find_or_create_qualification( - client, prolific_project_id, _qualification_name, + client, + prolific_project_id, + _qualification_name, ) attempts = 0 @@ -113,7 +119,9 @@ def create_new_qualification( # Append something somewhat random _qualification_name = f"{qualification_name}_{str(uuid4())}" qualification = prolific_utils.find_or_create_qualification( - client, prolific_project_id, _qualification_name, + client, + prolific_project_id, + _qualification_name, ) attempts += 1 if attempts > MAX_QUALIFICATION_ATTEMPTS: diff --git a/mephisto/abstractions/providers/prolific/prolific_unit.py b/mephisto/abstractions/providers/prolific/prolific_unit.py index efccf381d..e57da341c 100644 --- a/mephisto/abstractions/providers/prolific/prolific_unit.py +++ b/mephisto/abstractions/providers/prolific/prolific_unit.py @@ -51,16 +51,16 @@ class ProlificUnit(Unit): def __init__( self, - db: 'MephistoDB', + db: "MephistoDB", db_id: str, row: Optional[Mapping[str, Any]] = None, _used_new_call: bool = False, ): super().__init__(db, db_id, row=row, _used_new_call=_used_new_call) - self.datastore: 'ProlificDatastore' = db.get_datastore_for_provider(PROVIDER_TYPE) + self.datastore: "ProlificDatastore" = db.get_datastore_for_provider(PROVIDER_TYPE) self._last_sync_time = 0.0 self._sync_study_mapping() - self.__requester: Optional['ProlificRequester'] = None + self.__requester: Optional["ProlificRequester"] = None def _get_client(self, requester_name: str) -> Any: """Get a Prolific client for usage with `prolific_utils`""" @@ -68,7 +68,7 @@ def _get_client(self, requester_name: str) -> Any: @property def log_prefix(self) -> str: - return f'[Unit {self.db_id}] ' + return f"[Unit {self.db_id}] " def _sync_study_mapping(self) -> None: """Sync with the datastore to see if any mappings have updated""" @@ -76,9 +76,9 @@ def _sync_study_mapping(self) -> None: return try: mapping = dict(self.datastore.get_study_mapping(self.db_id)) - self.prolific_study_id = mapping['prolific_study_id'] - self.prolific_submission_id = mapping.get('prolific_submission_id') - self.assignment_time_in_seconds = mapping.get('assignment_time_in_seconds') + self.prolific_study_id = mapping["prolific_study_id"] + self.prolific_submission_id = mapping.get("prolific_submission_id") + self.assignment_time_in_seconds = mapping.get("assignment_time_in_seconds") except IndexError: # Study does not appear to exist self.prolific_study_id = None @@ -90,7 +90,9 @@ def _sync_study_mapping(self) -> None: self._last_sync_time = time.monotonic() - 1 def register_from_provider_data( - self, prolific_study_id: str, prolific_submission_id: str, + self, + prolific_study_id: str, + prolific_submission_id: str, ) -> None: """Update the datastore and local information from this registration""" self.datastore.set_submission_for_unit( @@ -114,10 +116,10 @@ def get_prolific_study_id(self) -> Optional[str]: self._sync_study_mapping() return self.prolific_study_id - def get_requester(self) -> 'ProlificRequester': + def get_requester(self) -> "ProlificRequester": """Wrapper around regular Requester as this will be ProlificRequester""" if self.__requester is None: - self.__requester = cast('ProlificRequester', super().get_requester()) + self.__requester = cast("ProlificRequester", super().get_requester()) return self.__requester def get_status(self) -> str: @@ -138,7 +140,7 @@ def get_status(self) -> str: return self.db_status # Get API client - requester: 'ProlificRequester' = self.get_requester() + requester: "ProlificRequester" = self.get_requester() client = self._get_client(requester.requester_name) # time.sleep(2) # Prolific servers may take time to bring their data up-to-date @@ -155,12 +157,13 @@ def get_status(self) -> str: # Get Submission from Prolific, record status datastore_unit = self.datastore.get_unit(self.db_id) - prolific_submission_id = datastore_unit['prolific_submission_id'] + prolific_submission_id = datastore_unit["prolific_submission_id"] prolific_submission = None if prolific_submission_id: prolific_submission = prolific_utils.get_submission(client, prolific_submission_id) self.datastore.update_submission_status( - prolific_submission_id, prolific_submission.status, + prolific_submission_id, + prolific_submission.status, ) # Check Unit status @@ -183,8 +186,8 @@ def get_status(self) -> str: # Check for NULL worker_id to prevent accidental reversal of unit's progress if external_status != AssignmentState.LAUNCHED: logger.debug( - f'{self.log_prefix}Moving Unit {self.db_id} status from ' - f'`{external_status}` to `{AssignmentState.LAUNCHED}`' + f"{self.log_prefix}Moving Unit {self.db_id} status from " + f"`{external_status}` to `{AssignmentState.LAUNCHED}`" ) external_status = AssignmentState.LAUNCHED elif prolific_submission.status == SubmissionStatus.PROCESSING: @@ -195,7 +198,7 @@ def get_status(self) -> str: prolific_submission.status, ) if not external_status: - raise Exception(f'Unexpected Submission status {prolific_submission.status}') + raise Exception(f"Unexpected Submission status {prolific_submission.status}") if external_status != local_status: self.set_db_status(external_status) @@ -215,8 +218,8 @@ def set_db_status(self, status: str) -> None: datastore_task_run = self.datastore.get_run(task_run_id) self.datastore.set_available_places_for_run( run_id=task_run_id, - actual_available_places=datastore_task_run['actual_available_places'] - 1, - listed_available_places=datastore_task_run['listed_available_places'] - 1, + actual_available_places=datastore_task_run["actual_available_places"] - 1, + listed_available_places=datastore_task_run["listed_available_places"] - 1, ) def clear_assigned_agent(self) -> None: @@ -246,7 +249,7 @@ def get_pay_amount(self) -> float: Return the amount that this Unit is costing against the budget, calculating additional fees as relevant """ - logger.debug(f'{self.log_prefix}Getting pay amount') + logger.debug(f"{self.log_prefix}Getting pay amount") requester = self.get_requester() client = self._get_client(requester.requester_name) @@ -257,7 +260,7 @@ def get_pay_amount(self) -> float: # TODO: what value should go in here when we auto-increment `total_available_places`? total_available_places=1, ) - logger.debug(f'{self.log_prefix}Pay amount: {total_amount}') + logger.debug(f"{self.log_prefix}Pay amount: {total_amount}") return total_amount @@ -276,8 +279,8 @@ def launch(self, task_url: str) -> None: task_run_id = self.get_task_run().db_id datastore_task_run = self.datastore.get_run(task_run_id) - actual_available_places = datastore_task_run['actual_available_places'] - listed_available_places = datastore_task_run['listed_available_places'] + actual_available_places = datastore_task_run["actual_available_places"] + listed_available_places = datastore_task_run["listed_available_places"] provider_increment_needed = False if actual_available_places is None: @@ -303,7 +306,8 @@ def launch(self, task_url: str) -> None: requester = self.get_requester() client = self._get_client(requester.requester_name) prolific_utils.increase_total_available_places_for_study( - client, datastore_task_run['prolific_study_id'], + client, + datastore_task_run["prolific_study_id"], ) # Change DB status @@ -324,8 +328,8 @@ def expire(self) -> float: task_run = self.get_task_run() datastore_task_run = self.datastore.get_run(task_run.db_id) - actual_available_places = datastore_task_run['actual_available_places'] - listed_available_places = datastore_task_run['listed_available_places'] + actual_available_places = datastore_task_run["actual_available_places"] + listed_available_places = datastore_task_run["listed_available_places"] listed_places_decrement = 1 if task_run.get_is_completed() else 0 self.datastore.set_available_places_for_run( @@ -338,7 +342,7 @@ def expire(self) -> float: # If Mephisto has expired all its units, we force-stop Prolific Study requester = self.get_requester() client = self._get_client(requester.requester_name) - prolific_utils.stop_study(client, datastore_task_run['prolific_study_id']) + prolific_utils.stop_study(client, datastore_task_run["prolific_study_id"]) # Update status if status in [AssignmentState.EXPIRED, AssignmentState.COMPLETED]: @@ -350,7 +354,7 @@ def expire(self) -> float: # amount of time we granted for working on this assignment if self.assignment_time_in_seconds is not None: delay = self.assignment_time_in_seconds - logger.debug(f'{self.log_prefix}Expiring a unit that is ASSIGNED after delay {delay}') + logger.debug(f"{self.log_prefix}Expiring a unit that is ASSIGNED after delay {delay}") prolific_study_id = self.get_prolific_study_id() requester = self.get_requester() @@ -373,24 +377,22 @@ def is_expired(self) -> bool: return self.get_status() == AssignmentState.EXPIRED @staticmethod - def new( - db: 'MephistoDB', assignment: 'Assignment', index: int, pay_amount: float - ) -> 'Unit': + def new(db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float) -> "Unit": """Create a Unit for the given assignment""" unit = ProlificUnit._register_unit(db, assignment, index, pay_amount, PROVIDER_TYPE) # Write unit in provider-specific datastore - datastore: 'ProlificDatastore' = db.get_datastore_for_provider(PROVIDER_TYPE) + datastore: "ProlificDatastore" = db.get_datastore_for_provider(PROVIDER_TYPE) task_run_details = dict(datastore.get_run(assignment.task_run_id)) logger.debug( f'{ProlificUnit.log_prefix}Create Unit "{unit.db_id}". ' - f'Task Run datastore details: {task_run_details}' + f"Task Run datastore details: {task_run_details}" ) datastore.create_unit( unit_id=unit.db_id, run_id=assignment.task_run_id, - prolific_study_id=task_run_details['prolific_study_id'], + prolific_study_id=task_run_details["prolific_study_id"], ) - logger.debug(f'{ProlificUnit.log_prefix}Unit was created in datastore successfully!') + logger.debug(f"{ProlificUnit.log_prefix}Unit was created in datastore successfully!") return unit diff --git a/mephisto/abstractions/providers/prolific/prolific_utils.py b/mephisto/abstractions/providers/prolific/prolific_utils.py index 9dcf405d2..ca10a7090 100644 --- a/mephisto/abstractions/providers/prolific/prolific_utils.py +++ b/mephisto/abstractions/providers/prolific/prolific_utils.py @@ -120,19 +120,19 @@ def _convert_eligibility_requirements(value: List[dict]) -> List[dict]: cls_kwargs[param_name] = conf_eligibility_requirement[param_name] eligibility_requirements.append(cls(**cls_kwargs).to_prolific_dict()) except Exception: - logger.exception('Could not convert passed Eligibility Requirements') + logger.exception("Could not convert passed Eligibility Requirements") # Generate human-readable log what Eligibility Requirements and with what parameters # are available. available_classes = inspect.getmembers( - sys.modules[eligibility_requirement_classes.__name__], inspect.isclass, + sys.modules[eligibility_requirement_classes.__name__], + inspect.isclass, ) - log_classes_dicts = [{ - 'name': c[0], - **{p: '' for p in c[1].params()} - } for c in available_classes] + log_classes_dicts = [ + {"name": c[0], **{p: "" for p in c[1].params()}} for c in available_classes + ] logger.info( - f'Available Eligibility Requirements in short form for config:\n' + - '\n'.join([str(i) for i in log_classes_dicts]) + f"Available Eligibility Requirements in short form for config:\n" + + "\n".join([str(i) for i in log_classes_dicts]) ) raise @@ -152,9 +152,7 @@ def check_balance(client: ProlificClient, **kwargs) -> Union[float, int, None]: return None try: - workspace_balance: WorkspaceBalance = client.Workspaces.get_balance( - id=workspace.id - ) + workspace_balance: WorkspaceBalance = client.Workspaces.get_balance(id=workspace.id) except (ProlificException, ValidationError): logger.exception(f"Could not receive a workspace balance with {workspace.id=}") raise @@ -279,9 +277,7 @@ def _find_qualification( project_id=prolific_project_id, ) except (ProlificException, ValidationError): - logger.exception( - f'Could not receive a qualifications for project "{prolific_project_id}"' - ) + logger.exception(f'Could not receive a qualifications for project "{prolific_project_id}"') raise for qualification in qualifications: @@ -406,9 +402,7 @@ def compose_completion_codes(code_suffix: str) -> List[dict]: # Initially provide a random completion code during study completion_codes_random = compose_completion_codes(uuid.uuid4().hex[:5]) - logger.debug( - f"Initial completion codes for creating Study: {completion_codes_random}" - ) + logger.debug(f"Initial completion codes for creating Study: {completion_codes_random}") try: # TODO (#1008): Make sure that all parameters are correct @@ -434,9 +428,7 @@ def compose_completion_codes(code_suffix: str) -> List[dict]: # This code will be used to redirect worker to Prolific's "Submission Completed" page # (see `mephisto.abstractions.providers.prolific.wrap_crowd_source.handleSubmitToProvider`) completion_codes_with_study_id = compose_completion_codes(study.id) - logger.debug( - f"Final completion codes for updating Study: {completion_codes_with_study_id}" - ) + logger.debug(f"Final completion codes for updating Study: {completion_codes_with_study_id}") study: Study = client.Studies.update( id=study.id, completion_codes=completion_codes_with_study_id, @@ -599,9 +591,7 @@ def pay_bonus( workspace_name=task_run_config.provider.prolific_workspace_name, ): # Just in case if Prolific adds showing an available balance for an account - logger.debug( - "Cannot pay bonus. Reason: Insufficient funds in your Prolific account." - ) + logger.debug("Cannot pay bonus. Reason: Insufficient funds in your Prolific account.") return False # Unlike all other Prolific endpoints working with cents, this one requires dollars @@ -670,12 +660,14 @@ def unblock_worker( def is_worker_blocked( - client: ProlificClient, task_run_config: "DictConfig", worker_id: str, + client: ProlificClient, + task_run_config: "DictConfig", + worker_id: str, ) -> bool: - """ Determine if the given worker is blocked by this client + """Determine if the given worker is blocked by this client - TODO (#1008): do we even need to check with Prolific "Blocked Participants" group - (as opposed to out datastore)? Because it doesn't reflect Prolific's internal banning + TODO (#1008): do we even need to check with Prolific "Blocked Participants" group + (as opposed to out datastore)? Because it doesn't reflect Prolific's internal banning """ workspace = find_or_create_prolific_workspace( client, @@ -696,9 +688,7 @@ def is_worker_blocked( return False try: - participants: List[ - Participant - ] = client.ParticipantGroups.list_participants_for_group( + participants: List[Participant] = client.ParticipantGroups.list_participants_for_group( block_list_qualification.id, ) except (ProlificException, ValidationError): @@ -759,14 +749,14 @@ def get_submission(client: ProlificClient, submission_id: str) -> Submission: def approve_work( - client: ProlificClient, study_id: str, worker_id: str, + client: ProlificClient, + study_id: str, + worker_id: str, ) -> Union[Submission, None]: submission: ListSubmission = _find_submission(client, study_id, worker_id) if not submission: - logger.warning( - f'No submission found for study "{study_id}" and participant "{worker_id}"' - ) + logger.warning(f'No submission found for study "{study_id}" and participant "{worker_id}"') return None # TODO (#1008): Maybe we need to expand handling submission statuses @@ -788,14 +778,14 @@ def approve_work( def reject_work( - client: ProlificClient, study_id: str, worker_id: str, + client: ProlificClient, + study_id: str, + worker_id: str, ) -> Union[Submission, None]: submission: ListSubmission = _find_submission(client, study_id, worker_id) if not submission: - logger.warning( - f'No submission found for study "{study_id}" and participant "{worker_id}"' - ) + logger.warning(f'No submission found for study "{study_id}" and participant "{worker_id}"') return None # TODO (#1008): Maybe we need to expand handling submission statuses diff --git a/mephisto/abstractions/providers/prolific/prolific_worker.py b/mephisto/abstractions/providers/prolific/prolific_worker.py index 32b6d4a95..1df4b1ab1 100644 --- a/mephisto/abstractions/providers/prolific/prolific_worker.py +++ b/mephisto/abstractions/providers/prolific/prolific_worker.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 import json + # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. @@ -42,13 +43,13 @@ class ProlificWorker(Worker): def __init__( self, - db: 'MephistoDB', + db: "MephistoDB", db_id: str, row: Optional[Mapping[str, Any]] = None, _used_new_call: bool = False, ): super().__init__(db, db_id, row=row, _used_new_call=_used_new_call) - self.datastore: 'ProlificDatastore' = db.get_datastore_for_provider(PROVIDER_TYPE) + self.datastore: "ProlificDatastore" = db.get_datastore_for_provider(PROVIDER_TYPE) def _get_client(self, requester_name: str) -> Any: """Get a Prolific client for usage with `prolific_utils`""" @@ -56,26 +57,26 @@ def _get_client(self, requester_name: str) -> Any: @property def log_prefix(self) -> str: - return f'[Worker {self.db_id}] ' + return f"[Worker {self.db_id}] " def get_prolific_participant_id(self): return self.worker_name def bonus_worker( - self, amount: float, reason: str, unit: Optional['Unit'] = None + self, amount: float, reason: str, unit: Optional["Unit"] = None ) -> Tuple[bool, str]: """Bonus a worker for work any reason. Return success of bonus""" - logger.debug(f'{self.log_prefix}Paying bonuses') + logger.debug(f"{self.log_prefix}Paying bonuses") if unit is None: - return False, 'bonusing via compensation tasks not yet available' + return False, "bonusing via compensation tasks not yet available" - unit: 'ProlificUnit' = cast('ProlificUnit', unit) + unit: "ProlificUnit" = cast("ProlificUnit", unit) if unit is None: # TODO(WISH) soft block from all requesters? Maybe have the main requester soft block? return ( False, - 'Paying bonuses without a unit not yet supported for ProlificWorkers', + "Paying bonuses without a unit not yet supported for ProlificWorkers", ) task_run: TaskRun = unit.get_task_run() @@ -87,9 +88,9 @@ def bonus_worker( study_id = unit.get_prolific_study_id() logger.debug( - f'{self.log_prefix}' - f'Trying to pay bonuses to worker {participant_id} for Study {study_id}. ' - f'Amount: {amount}' + f"{self.log_prefix}" + f"Trying to pay bonuses to worker {participant_id} for Study {study_id}. " + f"Amount: {amount}" ) prolific_utils.pay_bonus( client, @@ -99,40 +100,40 @@ def bonus_worker( study_id=study_id, ) - logger.debug(f'{self.log_prefix}Bonuses have been paid successfully') + logger.debug(f"{self.log_prefix}Bonuses have been paid successfully") - return True, '' + return True, "" @staticmethod - def _get_first_task_run(requester: 'Requester') -> 'TaskRun': + def _get_first_task_run(requester: "Requester") -> "TaskRun": task_runs: List[TaskRun] = requester.get_task_runs() return task_runs[0] def block_worker( self, reason: str, - unit: Optional['Unit'] = None, - requester: Optional['Requester'] = None, + unit: Optional["Unit"] = None, + requester: Optional["Requester"] = None, ) -> Tuple[bool, str]: """Block this worker for a specified reason. Return success of block""" - logger.debug(f'{self.log_prefix}Blocking worker {self.worker_name}') + logger.debug(f"{self.log_prefix}Blocking worker {self.worker_name}") if not unit and not requester: # TODO(WISH) soft block from all requesters? Maybe have the main requester soft block? return ( False, - 'Blocking without a unit or requester not yet supported for ProlificWorkers', + "Blocking without a unit or requester not yet supported for ProlificWorkers", ) elif unit and not requester: task_run = unit.get_assignment().get_task_run() - requester: 'ProlificRequester' = cast('ProlificRequester', task_run.get_requester()) + requester: "ProlificRequester" = cast("ProlificRequester", task_run.get_requester()) else: task_run = self._get_first_task_run(requester) - logger.debug(f'{self.log_prefix}Task Run: {task_run}') + logger.debug(f"{self.log_prefix}Task Run: {task_run}") task_run_args = task_run.args - requester: 'ProlificRequester' = cast('ProlificRequester', requester) + requester: "ProlificRequester" = cast("ProlificRequester", requester) client = self._get_client(requester.requester_name) prolific_utils.block_worker(client, task_run_args, self.worker_name, reason) self.datastore.set_worker_blocked(self.worker_name, is_blocked=True) @@ -146,27 +147,29 @@ def block_worker( db_qualification_ids, ) prolific_participant_group_ids = [ - p['prolific_participant_group_id'] for p in prolific_qualifications + p["prolific_participant_group_id"] for p in prolific_qualifications ] for prolific_participant_group_id in prolific_participant_group_ids: prolific_utils.remove_worker_qualification( - client, self.worker_name, prolific_participant_group_id, + client, + self.worker_name, + prolific_participant_group_id, ) - logger.debug(f'{self.log_prefix}Worker {self.worker_name} blocked') + logger.debug(f"{self.log_prefix}Worker {self.worker_name} blocked") - return True, '' + return True, "" - def unblock_worker(self, reason: str, requester: 'Requester') -> Tuple[bool, str]: + def unblock_worker(self, reason: str, requester: "Requester") -> Tuple[bool, str]: """Unblock a blocked worker for the specified reason. Return success of unblock""" - logger.debug(f'{self.log_prefix}Unlocking worker {self.worker_name}') + logger.debug(f"{self.log_prefix}Unlocking worker {self.worker_name}") task_run = self._get_first_task_run(requester) - logger.debug(f'{self.log_prefix}Task Run: {task_run}') + logger.debug(f"{self.log_prefix}Task Run: {task_run}") task_run_args = task_run.args - requester = cast('ProlificRequester', requester) + requester = cast("ProlificRequester", requester) client = self._get_client(requester.requester_name) prolific_utils.unblock_worker(client, task_run_args, self.worker_name, reason) self.datastore.set_worker_blocked(self.worker_name, is_blocked=False) @@ -174,29 +177,31 @@ def unblock_worker(self, reason: str, requester: 'Requester') -> Tuple[bool, str # Include unblocked Worker into all Participant Groups for currently running Studies, # if he is qualified at the moment self._grant_crowd_qualifications(client) - logger.debug(f'{self.log_prefix}Worker {self.worker_name} unblocked') + logger.debug(f"{self.log_prefix}Worker {self.worker_name} unblocked") - return True, '' + return True, "" - def is_blocked(self, requester: 'Requester') -> bool: + def is_blocked(self, requester: "Requester") -> bool: """Determine if a worker is blocked""" task_run = self._get_first_task_run(requester) - requester = cast('ProlificRequester', requester) + requester = cast("ProlificRequester", requester) is_blocked = self.datastore.get_worker_blocked(self.get_prolific_participant_id()) logger.debug( - f'{self.log_prefix}' + f"{self.log_prefix}" f'Worker "{self.worker_name}" {is_blocked=} for Task Run "{task_run.db_id}"' ) return is_blocked - def is_eligible(self, task_run: 'TaskRun') -> bool: + def is_eligible(self, task_run: "TaskRun") -> bool: """Determine if this worker is eligible for the given task run""" return True def _grant_crowd_qualifications( - self, client: ProlificClient, qualification_name: Optional[str] = None, + self, + client: ProlificClient, + qualification_name: Optional[str] = None, ) -> None: """ Grant specified qualification if `qualification_name` is passed or @@ -206,7 +211,7 @@ def _grant_crowd_qualifications( is_blocked = self.datastore.get_worker_blocked(prolific_participant_id) if is_blocked: logger.debug( - f'{self.log_prefix}' + f"{self.log_prefix}" f'Worker is blocked. Cannot grant qualification "{qualification_name}"' ) return None @@ -220,7 +225,7 @@ def _grant_crowd_qualifications( db_qualification_ids, ) qualifications_groups = [ - (json.loads(i['json_qual_logic']), i['prolific_participant_group_id']) + (json.loads(i["json_qual_logic"]), i["prolific_participant_group_id"]) for i in prolific_qualifications ] @@ -228,27 +233,33 @@ def _grant_crowd_qualifications( if worker_is_qualified(self, qualifications): # Worker is still qualified or was upgraded, and so is eligible now prolific_utils.give_worker_qualification( - client, self.worker_name, prolific_participant_group_id, + client, + self.worker_name, + prolific_participant_group_id, ) else: # Worker is now not eligible for this Participant Group anymore prolific_utils.remove_worker_qualification( - client, self.worker_name, prolific_participant_group_id, + client, + self.worker_name, + prolific_participant_group_id, ) logger.debug( - f'{self.log_prefix}Crowd qualification {qualification_name} has been granted ' + f"{self.log_prefix}Crowd qualification {qualification_name} has been granted " f'for Prolific Participant "{prolific_participant_id}"' ) def grant_crowd_qualification( - self, qualification_name: Optional[str] = None, value: int = 1, + self, + qualification_name: Optional[str] = None, + value: int = 1, ) -> None: """Grant qualification by the given name to this worker""" - logger.debug(f'{self.log_prefix}Granting crowd qualification: {qualification_name}') + logger.debug(f"{self.log_prefix}Granting crowd qualification: {qualification_name}") requester = cast( - 'ProlificRequester', + "ProlificRequester", self.db.find_requesters(provider_type=self.provider_type)[-1], ) client = self._get_client(requester.requester_name) @@ -258,39 +269,39 @@ def grant_crowd_qualification( def revoke_crowd_qualification(self, qualification_name: str) -> None: """Revoke qualification by given name from this worker""" - logger.debug(f'{self.log_prefix}Revoking crowd qualification: {qualification_name}') + logger.debug(f"{self.log_prefix}Revoking crowd qualification: {qualification_name}") p_qualification_details = self.datastore.get_qualification_mapping(qualification_name) if p_qualification_details is None: logger.error( - f'{self.log_prefix}No locally stored Prolific qualification (Participant Groups) ' - f'to revoke for name {qualification_name}' + f"{self.log_prefix}No locally stored Prolific qualification (Participant Groups) " + f"to revoke for name {qualification_name}" ) return None - requester = Requester.get(self.db, p_qualification_details['requester_id']) + requester = Requester.get(self.db, p_qualification_details["requester_id"]) assert isinstance( requester, ProlificRequester - ), 'Must be an Prolific requester from Prolific qualifications' + ), "Must be an Prolific requester from Prolific qualifications" client = self._get_client(requester.requester_name) p_worker_id = self.get_prolific_participant_id() - p_qualification_id = p_qualification_details['prolific_participant_group_id'] + p_qualification_id = p_qualification_details["prolific_participant_group_id"] prolific_utils.remove_worker_qualification(client, p_worker_id, p_qualification_id) logger.debug( - f'{self.log_prefix}Crowd qualification {qualification_name} has been revoked ' + f"{self.log_prefix}Crowd qualification {qualification_name} has been revoked " f'for Prolific Participant "{p_worker_id}"' ) return None @staticmethod - def new(db: 'MephistoDB', worker_id: str) -> 'Worker': + def new(db: "MephistoDB", worker_id: str) -> "Worker": new_worker = ProlificWorker._register_worker(db, worker_id, PROVIDER_TYPE) # Save worker in provider-specific datastore - datastore: 'ProlificDatastore' = db.get_datastore_for_provider(PROVIDER_TYPE) + datastore: "ProlificDatastore" = db.get_datastore_for_provider(PROVIDER_TYPE) datastore.ensure_worker_exists(worker_id) return new_worker diff --git a/mephisto/abstractions/test/architect_tester.py b/mephisto/abstractions/test/architect_tester.py index afd295da3..0b8feb795 100644 --- a/mephisto/abstractions/test/architect_tester.py +++ b/mephisto/abstractions/test/architect_tester.py @@ -116,14 +116,10 @@ def test_init_architect(self) -> None: issubclass(self.ArchitectClass, Architect), "Implemented ArchitectClass does not extend Architect", ) - self.assertNotEqual( - self.ArchitectClass, Architect, "Can not use base Architect" - ) + self.assertNotEqual(self.ArchitectClass, Architect, "Can not use base Architect") arch_args = self.ArchitectClass.ArgsClass() args = OmegaConf.structured(MephistoConfig(architect=arch_args)) - architect = self.ArchitectClass( - self.db, args, EMPTY_STATE, self.task_run, self.build_dir - ) + architect = self.ArchitectClass(self.db, args, EMPTY_STATE, self.task_run, self.build_dir) def get_architect(self) -> Architect: """ diff --git a/mephisto/abstractions/test/blueprint_tester.py b/mephisto/abstractions/test/blueprint_tester.py index 449d10e84..b0a891523 100644 --- a/mephisto/abstractions/test/blueprint_tester.py +++ b/mephisto/abstractions/test/blueprint_tester.py @@ -67,9 +67,7 @@ def get_test_assignment(self) -> Assignment: """Create a test assignment for self.task_run using mock agents""" raise NotImplementedError() - def assignment_is_tracked( - self, task_runner: TaskRunner, assignment: Assignment - ) -> bool: + def assignment_is_tracked(self, task_runner: TaskRunner, assignment: Assignment) -> bool: """ Return whether or not this task is currently being tracked (run) by the given task runner. This should be false unless @@ -136,9 +134,7 @@ def test_ensure_valid_statuses(self): found_keys = [k for k in dir(a_state) if k.startswith("STATUS_")] found_vals = [getattr(a_state, k) for k in found_keys] for v in found_vals: - self.assertIn( - v, found_valid, f"Expected to find {v} in valid list {found_valid}" - ) + self.assertIn(v, found_valid, f"Expected to find {v} in valid list {found_valid}") for v in found_complete: self.assertIn( v, diff --git a/mephisto/abstractions/test/crowd_provider_tester.py b/mephisto/abstractions/test/crowd_provider_tester.py index 9a911a561..51ff0d805 100644 --- a/mephisto/abstractions/test/crowd_provider_tester.py +++ b/mephisto/abstractions/test/crowd_provider_tester.py @@ -94,9 +94,7 @@ def test_init_registers_datastore(self) -> None: for all crowd providers. """ ProviderClass = self.CrowdProviderClass - self.assertFalse( - self.db.has_datastore_for_provider(ProviderClass.PROVIDER_TYPE) - ) + self.assertFalse(self.db.has_datastore_for_provider(ProviderClass.PROVIDER_TYPE)) # Initialize the provider provider = ProviderClass(self.db) self.assertTrue(self.db.has_datastore_for_provider(ProviderClass.PROVIDER_TYPE)) @@ -107,9 +105,7 @@ def test_init_object_registers_datastore(self) -> None: for all crowd providers. """ ProviderClass = self.CrowdProviderClass - self.assertFalse( - self.db.has_datastore_for_provider(ProviderClass.PROVIDER_TYPE) - ) + self.assertFalse(self.db.has_datastore_for_provider(ProviderClass.PROVIDER_TYPE)) # Initialize the requester RequesterClass = ProviderClass.RequesterClass requester = RequesterClass.new(self.db, self.get_test_requester_name()) diff --git a/mephisto/abstractions/test/data_model_database_tester.py b/mephisto/abstractions/test/data_model_database_tester.py index 514abb7ac..8b1fba535 100644 --- a/mephisto/abstractions/test/data_model_database_tester.py +++ b/mephisto/abstractions/test/data_model_database_tester.py @@ -316,9 +316,7 @@ def test_update_task_failures(self) -> None: # But not after we've created a task run requester_name, requester_id = get_test_requester(db) init_params = json.dumps(OmegaConf.to_yaml(TaskRunArgs.get_mock_params())) - task_run_id = db.new_task_run( - task_id_2, requester_id, init_params, "mock", "mock" - ) + task_run_id = db.new_task_run(task_id_2, requester_id, init_params, "mock", "mock") with self.assertRaises(MephistoDBException): db.update_task(task_id_2, task_name=task_name_2) @@ -450,9 +448,7 @@ def test_task_run(self) -> None: # Check creation and retrieval of a task_run init_params = json.dumps(OmegaConf.to_yaml(TaskRunArgs.get_mock_params())) - task_run_id = db.new_task_run( - task_id, requester_id, init_params, "mock", "mock" - ) + task_run_id = db.new_task_run(task_id, requester_id, init_params, "mock", "mock") self.assertIsNotNone(task_run_id) self.assertTrue(isinstance(task_run_id, str)) task_run_row = db.get_task_run(task_run_id) @@ -889,9 +885,7 @@ def test_qualifications(self) -> None: qualifications = db.find_qualifications(qualification_name) self.assertEqual(len(qualifications), 0, "Qualification not remove") granted_quals = db.check_granted_qualifications() - self.assertEqual( - len(granted_quals), 0, "Cascade granted qualification not removed" - ) + self.assertEqual(len(granted_quals), 0, "Cascade granted qualification not removed") # cant retrieve the qualification directly anymore with self.assertRaises(EntryDoesNotExistException): @@ -907,9 +901,7 @@ def test_onboarding_agents(self) -> None: task = task_run.get_task() worker_name, worker_id = get_test_worker(db) - onboarding_agent_id = db.new_onboarding_agent( - worker_id, task.db_id, task_run_id, "mock" - ) + onboarding_agent_id = db.new_onboarding_agent(worker_id, task.db_id, task_run_id, "mock") self.assertIsNotNone(onboarding_agent_id) onboarding_agent = OnboardingAgent.get(db, onboarding_agent_id) diff --git a/mephisto/client/api.py b/mephisto/client/api.py index 84a90d157..9443c48af 100644 --- a/mephisto/client/api.py +++ b/mephisto/client/api.py @@ -78,8 +78,7 @@ def launch_options(): "architect_types": architect_types, "provider_types": provider_types, "blueprint_types": [ - {"name": bp, "rank": idx + 1} - for (idx, bp) in enumerate(blueprint_types) + {"name": bp, "rank": idx + 1} for (idx, bp) in enumerate(blueprint_types) ], } ) @@ -110,9 +109,7 @@ def view_unit(task_id): # TODO # MOCK - return jsonify( - {"id": task_id, "view_path": "https://google.com", "data": {"name": "me"}} - ) + return jsonify({"id": task_id, "view_path": "https://google.com", "data": {"name": "me"}}) @api.route("/task_runs/options") @@ -139,13 +136,9 @@ def requester_register(requester_type): parsed_options = parse_arg_dict(RequesterClass, options) except Exception as e: traceback.print_exc(file=sys.stdout) - return jsonify( - {"success": False, "msg": f"error in parsing arguments: {str(e)}"} - ) + return jsonify({"success": False, "msg": f"error in parsing arguments: {str(e)}"}) if "name" not in parsed_options: - return jsonify( - {"success": False, "msg": "No name was specified for the requester."} - ) + return jsonify({"success": False, "msg": "No name was specified for the requester."}) db = app.extensions["db"] requesters = db.find_requesters(requester_name=parsed_options["name"]) @@ -179,9 +172,7 @@ def get_submitted_data(): for task_run in task_runs: assignments += task_run.get_assignments() - assignments += [ - Assignment.get(db, assignment_id) for assignment_id in assignment_ids - ] + assignments += [Assignment.get(db, assignment_id) for assignment_id in assignment_ids] if len(statuses) == 0: statuses = [ diff --git a/mephisto/client/cli.py b/mephisto/client/cli.py index 59e295b5b..9095abb13 100644 --- a/mephisto/client/cli.py +++ b/mephisto/client/cli.py @@ -38,9 +38,7 @@ def cli(): click.rich_click.USE_RICH_MARKUP = True click.rich_click.SHOW_ARGUMENTS = True -click.rich_click.ERRORS_SUGGESTION = ( - "\nTry running the '--help' flag for more information." -) +click.rich_click.ERRORS_SUGGESTION = "\nTry running the '--help' flag for more information." click.rich_click.ERRORS_EPILOGUE = ( "To find out more, visit https://mephisto.ai/docs/guides/quickstart/\n" ) @@ -105,7 +103,7 @@ def config(identifier, value): @click.option("--db", "database_task_name", type=(str), default=None) @click.option("--all/--one-by-one", "all_data", default=False) @click.option("-d", "--debug", type=(bool), default=False) -@click.option("-h", "--host", type=(str), default='127.0.0.1') +@click.option("-h", "--host", type=(str), default="127.0.0.1") def review( review_app_dir, port, @@ -141,7 +139,7 @@ def review( raise click.BadParameter( f'The task name "{database_task_name}" did not exist in MephistoDB.\n\n' f'Perhaps you meant one of these? {", ".join(name_list)}\n\n' - f'Flag usage: mephisto review --db [task_name]\n' + f"Flag usage: mephisto review --db [task_name]\n" ) run( @@ -194,16 +192,12 @@ def list_requesters(): print("[red]No requesters found[/red]") -@cli.command( - "register", cls=RichCommand, context_settings={"ignore_unknown_options": True} -) +@cli.command("register", cls=RichCommand, context_settings={"ignore_unknown_options": True}) @click.argument("args", nargs=-1) def register_provider(args): """Register a requester with a crowd provider""" if len(args) == 0: - print( - "\n[red]Usage: mephisto register arg1=value arg2=value[/red]" - ) + print("\n[red]Usage: mephisto register arg1=value arg2=value[/red]") print("\n[b]Valid Providers[/b]") provider_text = """""" for provider in get_valid_provider_types(): @@ -272,9 +266,7 @@ def run_wut(args): get_wut_arguments(args) -@cli.command( - "scripts", cls=RichCommand, context_settings={"ignore_unknown_options": True} -) +@cli.command("scripts", cls=RichCommand, context_settings={"ignore_unknown_options": True}) @click.argument("script_type", required=False, nargs=1) @click.argument("script_name", required=False, nargs=1) def run_script(script_type, script_name): @@ -345,8 +337,7 @@ def print_non_markdown_list(items: List[str]): } if script_name is None or ( - script_name - not in script_type_to_scripts_data[script_type]["valid_script_names"] + script_name not in script_type_to_scripts_data[script_type]["valid_script_names"] ): print("") raise click.UsageError( @@ -359,9 +350,7 @@ def print_non_markdown_list(items: List[str]): script_type_to_scripts_data[script_type]["scripts"][script_name]() -@cli.command( - "metrics", cls=RichCommand, context_settings={"ignore_unknown_options": True} -) +@cli.command("metrics", cls=RichCommand, context_settings={"ignore_unknown_options": True}) @click.argument("args", nargs=-1) def metrics_cli(args): from mephisto.utils.metrics import ( @@ -376,9 +365,7 @@ def metrics_cli(args): if len(args) == 0 or args[0] not in ["install", "view", "cleanup"]: print("\n[red]Usage: mephisto metrics [/red]") metrics_table = create_table(["Property", "Value"], "Metrics Arguments") - metrics_table.add_row( - "install", f"Installs Prometheus and Grafana to {METRICS_DIR}" - ) + metrics_table.add_row("install", f"Installs Prometheus and Grafana to {METRICS_DIR}") metrics_table.add_row( "view", "Launches a Prometheus and Grafana server, and shuts down on exit", @@ -397,17 +384,13 @@ def metrics_cli(args): run_install_script() elif command == "view": if not metrics_are_installed(): - click.echo( - f"Metrics aren't installed! Use `mephisto metrics install` first." - ) + click.echo(f"Metrics aren't installed! Use `mephisto metrics install` first.") return click.echo(f"Servers launching - use ctrl-C to shutdown") launch_servers_and_wait() else: # command == 'cleanup': if not metrics_are_installed(): - click.echo( - f"Metrics aren't installed! Use `mephisto metrics install` first." - ) + click.echo(f"Metrics aren't installed! Use `mephisto metrics install` first.") return click.echo(f"Cleaning up existing servers if they exist") shutdown_prometheus_server() diff --git a/mephisto/client/cli_commands.py b/mephisto/client/cli_commands.py index 5fd84e0ee..7738d1f87 100644 --- a/mephisto/client/cli_commands.py +++ b/mephisto/client/cli_commands.py @@ -15,12 +15,8 @@ def get_wut_arguments(args): ) if len(args) == 0: - print( - "\n[red]Usage: mephisto wut [=] [...specific args to check][/red]" - ) - abstractions_table = create_table( - ["Abstraction", "Description"], "\n\n[b]Abstractions[/b]" - ) + print("\n[red]Usage: mephisto wut [=] [...specific args to check][/red]") + abstractions_table = create_table(["Abstraction", "Description"], "\n\n[b]Abstractions[/b]") abstractions_table.add_row( "blueprint", f"The blueprint contains all of the related code required to set up a task run. \nValid blueprints types are [b]{get_valid_blueprint_types()}[/b]", @@ -69,18 +65,12 @@ def get_wut_arguments(args): if abstraction == "blueprint": click.echo("The blueprint determines the task content.\n") valid_blueprints_text = """**Valid blueprints are:**""" - print_out_valid_options( - valid_blueprints_text, get_valid_blueprint_types() - ) + print_out_valid_options(valid_blueprints_text, get_valid_blueprint_types()) return elif abstraction == "architect": - click.echo( - "The architect determines the server where a task is hosted.\n" - ) + click.echo("The architect determines the server where a task is hosted.\n") valid_architect_text = """**Valid architects are:**""" - print_out_valid_options( - valid_architect_text, get_valid_architect_types() - ) + print_out_valid_options(valid_architect_text, get_valid_architect_types()) return elif abstraction == "requester": click.echo( @@ -88,14 +78,10 @@ def get_wut_arguments(args): "Use `mephisto requesters` to see registered requesters, and `mephisto register ` to register.\n" ) valid_requester_text = """**Valid requesters are:**""" - print_out_valid_options( - valid_requester_text, get_valid_provider_types() - ) + print_out_valid_options(valid_requester_text, get_valid_provider_types()) return elif abstraction == "provider": - click.echo( - "The crowd provider determines the source of the crowd workers.\n" - ) + click.echo("The crowd provider determines the source of the crowd workers.\n") valid_provider_text = """**Valid providers are:**""" print_out_valid_options(valid_provider_text, get_valid_provider_types()) return @@ -121,9 +107,7 @@ def get_wut_arguments(args): valid = get_valid_provider_types() elif abstraction == "requester": try: - target_class = get_crowd_provider_from_type( - abstract_value - ).RequesterClass + target_class = get_crowd_provider_from_type(abstract_value).RequesterClass except: valid = get_valid_provider_types() if valid is not None: @@ -151,15 +135,11 @@ def get_wut_arguments(args): first_arg_keys = list(checking_args[first_arg].keys()) args_table = create_table( first_arg_keys, - "\n[b]{abstraction} Arguments[/b]".format( - abstraction=abstraction.capitalize() - ), + "\n[b]{abstraction} Arguments[/b]".format(abstraction=abstraction.capitalize()), ) for arg in checking_args: if arg in argument_overrides: - checking_args[arg][argument_overrides[arg][0]] = argument_overrides[ - arg - ][1] + checking_args[arg][argument_overrides[arg][0]] = argument_overrides[arg][1] arg_values = list(checking_args[arg].values()) arg_values = [str(x) for x in arg_values] args_table.add_row(*arg_values) diff --git a/mephisto/client/full/server.py b/mephisto/client/full/server.py index 1b2980006..fd8eccef9 100644 --- a/mephisto/client/full/server.py +++ b/mephisto/client/full/server.py @@ -16,9 +16,7 @@ def get_app(): - app = Flask( - __name__, static_url_path="/static", static_folder="webapp/build/static" - ) + app = Flask(__name__, static_url_path="/static", static_folder="webapp/build/static") app.config.from_object(Config) app.register_blueprint(api, url_prefix="/api/v1") @@ -39,12 +37,8 @@ def index(path): @app.after_request def after_request(response): response.headers.add("Access-Control-Allow-Origin", "*") - response.headers.add( - "Access-Control-Allow-Headers", "Content-Type,Authorization" - ) - response.headers.add( - "Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS" - ) + response.headers.add("Access-Control-Allow-Headers", "Content-Type,Authorization") + response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS") response.headers.add("Cache-Control", "no-store") return response diff --git a/mephisto/client/review/review_server.py b/mephisto/client/review/review_server.py index c24817afd..55b8265d8 100644 --- a/mephisto/client/review/review_server.py +++ b/mephisto/client/review/review_server.py @@ -43,9 +43,7 @@ def run( RESULT_SUCCESS = "SUCCESS" RESULT_ERROR = "ERROR" - DataQueryResult = collections.namedtuple( - "DataQueryResult", ["data_list", "total_pages"] - ) + DataQueryResult = collections.namedtuple("DataQueryResult", ["data_list", "total_pages"]) if not debug or output == "": # disable noisy logging of flask, https://stackoverflow.com/a/18379764 @@ -138,10 +136,7 @@ def consume_all_data(page, results_per_page=RESULTS_PER_PAGE_DEFAULT, filters=No # If differnce in time since the last update to the data list is over 5 minutes, update list again # This can only be done for usage with mephistoDB as standard input is exhausted when originally creating the list now = datetime.now() - if ( - USE_TIMEOUT - and (now - datalist_update_time).total_seconds() > TIMEOUT_IN_SECONDS - ): + if USE_TIMEOUT and (now - datalist_update_time).total_seconds() > TIMEOUT_IN_SECONDS: refresh_all_list_data() filtered_data_list = all_data_list @@ -158,9 +153,7 @@ def consume_all_data(page, results_per_page=RESULTS_PER_PAGE_DEFAULT, filters=No if first_index > list_len - 1: filtered_data_list = [] else: - results_per_page = ( - min(first_index + results_per_page, list_len) - first_index - ) + results_per_page = min(first_index + results_per_page, list_len) - first_index if results_per_page < 0: filtered_data_list = [] else: @@ -202,9 +195,7 @@ def data(): raise RuntimeError("Not running with the Werkzeug Server") func() - return jsonify( - {"finished": finished, "data": current_data if not finished else None} - ) + return jsonify({"finished": finished, "data": current_data if not finished else None}) @app.route("/submit_current_task", methods=["GET", "POST"]) def next_task(): @@ -223,9 +214,7 @@ def next_task(): } ) result = ( - request.get_json(force=True) - if request.method == "POST" - else request.args.get("result") + request.get_json(force=True) if request.method == "POST" else request.args.get("result") ) if output == "": @@ -255,9 +244,7 @@ def task_data_by_id(id): if all_data: list_len = len(all_data_list) if id is None or id < 0 or id >= list_len: - return jsonify( - {"error": f"Data with ID: {id} does not exist", "mode": MODE} - ) + return jsonify({"error": f"Data with ID: {id} does not exist", "mode": MODE}) return jsonify({"data": all_data_list[id], "mode": MODE}) else: if id is None or id != counter - 1: @@ -294,9 +281,7 @@ def task_data_by_id(id): if not all_data: ready_for_next.set() time.sleep(0) - return jsonify( - {"result": RESULT_SUCCESS, "finished": finished, "mode": MODE} - ) + return jsonify({"result": RESULT_SUCCESS, "finished": finished, "mode": MODE}) @app.route("/data") def all_task_data(): @@ -348,12 +333,8 @@ def index(id): @app.after_request def after_request(response): response.headers.add("Access-Control-Allow-Origin", "*") - response.headers.add( - "Access-Control-Allow-Headers", "Content-Type,Authorization" - ) - response.headers.add( - "Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS" - ) + response.headers.add("Access-Control-Allow-Headers", "Content-Type,Authorization") + response.headers.add("Access-Control-Allow-Methods", "GET,PUT,POST,DELETE,OPTIONS") response.headers.add("Cache-Control", "no-store") return response @@ -379,7 +360,7 @@ def after_request(response): thread = threading.Thread(target=consume_data, name="review-server-thread") thread.start() - host = host or '127.0.0.1' + host = host or "127.0.0.1" print(f"Running on http://{host}:{port}/ (Press CTRL+C to quit)") sys.stdout.flush() app.run(debug=False, port=port, host=host) diff --git a/mephisto/data_model/agent.py b/mephisto/data_model/agent.py index c43b597bf..99cd9242e 100644 --- a/mephisto/data_model/agent.py +++ b/mephisto/data_model/agent.py @@ -132,9 +132,7 @@ def set_live_run(self, live_run: "LiveTaskRun") -> None: def get_live_run(self) -> "LiveTaskRun": """Return the associated live run for this agent. Throw if not set""" if self._associated_live_run is None: - raise AssertionError( - "Should not be getting the live run, not set for given agent" - ) + raise AssertionError("Should not be getting the live run, not set for given agent") return self._associated_live_run def agent_in_active_run(self) -> bool: @@ -188,9 +186,7 @@ def observe(self, live_update: "Dict[str, Any]") -> None: live_run = self.get_live_run() live_run.client_io.send_live_update(self.get_agent_id(), live_update) - def get_live_update( - self, timeout: Optional[int] = None - ) -> Optional[Dict[str, Any]]: + def get_live_update(self, timeout: Optional[int] = None) -> Optional[Dict[str, Any]]: """ Request information from the Agent's frontend. If non-blocking, (timeout is None) should return None if no actions are ready @@ -213,9 +209,7 @@ def get_live_update( raise AgentReturnedError(self.db_id) self.update_status(AgentState.STATUS_TIMEOUT) raise AgentTimeoutError(timeout, self.db_id) - assert ( - not self.pending_actions.empty() - ), "has_live_update released without an action!" + assert not self.pending_actions.empty(), "has_live_update released without an action!" act = self.pending_actions.get() @@ -289,15 +283,11 @@ def handle_metadata_submit(self, data: Dict[str, Any]) -> None: "accepted": False, } if self.state.metadata.tips is None: - self.state.update_metadata( - property_name="tips", property_value=[tip_to_add] - ) + self.state.update_metadata(property_name="tips", property_value=[tip_to_add]) else: copy_of_tips = self.state.metadata.tips.copy() copy_of_tips.append(tip_to_add) - self.state.update_metadata( - property_name="tips", property_value=copy_of_tips - ) + self.state.update_metadata(property_name="tips", property_value=copy_of_tips) elif "feedback" in data: questions_and_answers = data["feedback"]["data"] @@ -370,9 +360,7 @@ def get_status(self) -> str: raise NotImplementedError -class Agent( - _AgentBase, MephistoDataModelComponentMixin, metaclass=MephistoDBBackedABCMeta -): +class Agent(_AgentBase, MephistoDataModelComponentMixin, metaclass=MephistoDBBackedABCMeta): """ This class encompasses a worker as they are working on an individual assignment. It maintains details for the current task at hand such as start and end time, @@ -431,9 +419,7 @@ def __new__( if row is None: row = db.get_agent(db_id) assert row is not None, f"Given db_id {db_id} did not exist in given db" - correct_class = get_crowd_provider_from_type( - row["provider_type"] - ).AgentClass + correct_class = get_crowd_provider_from_type(row["provider_type"]).AgentClass return super().__new__(correct_class) else: # We are constructing another instance directly @@ -487,9 +473,7 @@ def update_status(self, new_status: str) -> None: self.db_status = new_status if self.agent_in_active_run(): live_run = self.get_live_run() - live_run.loop_wrap.execute_coro( - live_run.worker_pool.push_status_update(self) - ) + live_run.loop_wrap.execute_coro(live_run.worker_pool.push_status_update(self)) if new_status in [ AgentState.STATUS_RETURNED, AgentState.STATUS_DISCONNECT, @@ -508,10 +492,7 @@ def update_status(self, new_status: str) -> None: # Metrics changes ACTIVE_AGENT_STATUSES.labels(status=old_status, agent_type="main").dec() ACTIVE_AGENT_STATUSES.labels(status=new_status, agent_type="main").inc() - if ( - old_status not in AgentState.complete() - and new_status in AgentState.complete() - ): + if old_status not in AgentState.complete() and new_status in AgentState.complete(): ACTIVE_WORKERS.labels(worker_id=self.worker_id, agent_type="main").dec() @staticmethod @@ -532,9 +513,7 @@ def _register_agent( provider_type, ) a = Agent.get(db, db_id) - ACTIVE_AGENT_STATUSES.labels( - status=AgentState.STATUS_NONE, agent_type="main" - ).inc() + ACTIVE_AGENT_STATUSES.labels(status=AgentState.STATUS_NONE, agent_type="main").inc() ACTIVE_WORKERS.labels(worker_id=worker.db_id, agent_type="main").inc() logger.debug(f"Registered new agent {a} for {unit}.") a.update_status(AgentState.STATUS_ACCEPTED) @@ -572,9 +551,7 @@ def get_status(self) -> str: self.has_live_update.set() if self.agent_in_active_run(): live_run = self.get_live_run() - live_run.loop_wrap.execute_coro( - live_run.worker_pool.push_status_update(self) - ) + live_run.loop_wrap.execute_coro(live_run.worker_pool.push_status_update(self)) self.db_status = row["status"] return self.db_status @@ -699,9 +676,7 @@ def update_status(self, new_status: str) -> None: AgentState.STATUS_REJECTED, ]: live_run = self.get_live_run() - live_run.loop_wrap.execute_coro( - live_run.worker_pool.push_status_update(self) - ) + live_run.loop_wrap.execute_coro(live_run.worker_pool.push_status_update(self)) if new_status in [AgentState.STATUS_RETURNED, AgentState.STATUS_DISCONNECT]: # Disconnect statuses should free any pending acts self.has_live_update.set() @@ -710,13 +685,8 @@ def update_status(self, new_status: str) -> None: # Metrics changes ACTIVE_AGENT_STATUSES.labels(status=old_status, agent_type="onboarding").dec() ACTIVE_AGENT_STATUSES.labels(status=new_status, agent_type="onboarding").inc() - if ( - old_status not in AgentState.complete() - and new_status in AgentState.complete() - ): - ACTIVE_WORKERS.labels( - worker_id=self.worker_id, agent_type="onboarding" - ).dec() + if old_status not in AgentState.complete() and new_status in AgentState.complete(): + ACTIVE_WORKERS.labels(worker_id=self.worker_id, agent_type="onboarding").dec() def get_status(self) -> str: """Get the status of this agent in their work on their unit""" @@ -750,9 +720,7 @@ def new(db: "MephistoDB", worker: Worker, task_run: "TaskRun") -> "OnboardingAge worker.db_id, task_run.task_id, task_run.db_id, task_run.task_type ) a = OnboardingAgent.get(db, db_id) - ACTIVE_AGENT_STATUSES.labels( - status=AgentState.STATUS_NONE, agent_type="onboarding" - ).inc() + ACTIVE_AGENT_STATUSES.labels(status=AgentState.STATUS_NONE, agent_type="onboarding").inc() ACTIVE_WORKERS.labels(worker_id=worker.db_id, agent_type="onboarding").inc() logger.debug(f"Registered new {a} for worker {worker}.") return a diff --git a/mephisto/data_model/assignment.py b/mephisto/data_model/assignment.py index 9e0d88eed..d6290104d 100644 --- a/mephisto/data_model/assignment.py +++ b/mephisto/data_model/assignment.py @@ -43,9 +43,7 @@ def dumpJSON(self, fp: IO[str]): @staticmethod def loadFromJSON(fp: IO[str]): as_dict = json.load(fp) - return InitializationData( - shared=as_dict["shared"], unit_data=as_dict["unit_data"] - ) + return InitializationData(shared=as_dict["shared"], unit_data=as_dict["unit_data"]) class Assignment(MephistoDataModelComponentMixin, metaclass=MephistoDBBackedMeta): @@ -135,12 +133,7 @@ def get_status(self) -> str: # If any are still assigned, consider the whole thing assigned return AssignmentState.ASSIGNED - if all( - [ - s in [AssignmentState.ACCEPTED, AssignmentState.REJECTED] - for s in statuses - ] - ): + if all([s in [AssignmentState.ACCEPTED, AssignmentState.REJECTED] for s in statuses]): return AssignmentState.MIXED if all([s in AssignmentState.final_agent() for s in statuses]): @@ -183,9 +176,7 @@ def get_units(self, status: Optional[str] = None) -> List["Unit"]: Get units for this assignment, optionally constrained by the specific status. """ - assert ( - status is None or status in AssignmentState.valid_unit() - ), "Invalid assignment status" + assert status is None or status in AssignmentState.valid_unit(), "Invalid assignment status" units = self.db.find_units(assignment_id=self.db_id) if status is not None: units = [u for u in units if u.get_status() == status] @@ -238,9 +229,7 @@ def new( assign_dir = os.path.join(run_dir, db_id) os.makedirs(assign_dir) if assignment_data is not None: - with open( - os.path.join(assign_dir, ASSIGNMENT_DATA_FILE), "w+" - ) as json_file: + with open(os.path.join(assign_dir, ASSIGNMENT_DATA_FILE), "w+") as json_file: json.dump(assignment_data, json_file) assignment = Assignment.get(db, db_id) logger.debug(f"{assignment} created for {task_run}") diff --git a/mephisto/data_model/requester.py b/mephisto/data_model/requester.py index cb6b8c013..9368c7616 100644 --- a/mephisto/data_model/requester.py +++ b/mephisto/data_model/requester.py @@ -88,9 +88,7 @@ def __new__( if row is None: row = db.get_requester(db_id) assert row is not None, f"Given db_id {db_id} did not exist in given db" - correct_class = get_crowd_provider_from_type( - row["provider_type"] - ).RequesterClass + correct_class = get_crowd_provider_from_type(row["provider_type"]).RequesterClass return super().__new__(correct_class) else: # We are constructing another instance directly @@ -124,9 +122,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.db_id})" @staticmethod - def _register_requester( - db: "MephistoDB", requester_id: str, provider_type: str - ) -> "Requester": + def _register_requester(db: "MephistoDB", requester_id: str, provider_type: str) -> "Requester": """ Create an entry for this requester in the database """ diff --git a/mephisto/data_model/task_run.py b/mephisto/data_model/task_run.py index 17f14a3ae..476cc917b 100644 --- a/mephisto/data_model/task_run.py +++ b/mephisto/data_model/task_run.py @@ -237,9 +237,7 @@ def get_valid_units_for_worker(self, worker: "Worker") -> List["Unit"]: currently_active = len(current_units) if config.allowed_concurrent != 0: if currently_active >= config.allowed_concurrent: - logger.debug( - f"{worker} at maximum concurrent units {currently_active}" - ) + logger.debug(f"{worker} at maximum concurrent units {currently_active}") return [] # currently at the maximum number of concurrent units if config.maximum_units_per_worker != 0: completed_types = AssignmentState.completed() @@ -250,10 +248,7 @@ def get_valid_units_for_worker(self, worker: "Worker") -> List["Unit"]: currently_completed = len( [u for u in related_units if u.db_status in completed_types] ) - if ( - currently_active + currently_completed - >= config.maximum_units_per_worker - ): + if currently_active + currently_completed >= config.maximum_units_per_worker: logger.debug( f"{worker} at maximum units {currently_active}, {currently_completed}" ) @@ -278,19 +273,13 @@ def get_valid_units_for_worker(self, worker: "Worker") -> List["Unit"]: # Can use db_status directly rather than polling in the critical path, as in # the worst case we miss the transition from an active to launched unit valid_units = [ - u - for u in units - if u.db_status == AssignmentState.LAUNCHED and u.unit_index >= 0 + u for u in units if u.db_status == AssignmentState.LAUNCHED and u.unit_index >= 0 ] logger.debug(f"Found {len(valid_units)} available units") # Should load cached blueprint for SharedTaskState blueprint = self.get_blueprint() - ret_units = [ - u - for u in valid_units - if blueprint.shared_state.worker_can_do_unit(worker, u) - ] + ret_units = [u for u in valid_units if blueprint.shared_state.worker_can_do_unit(worker, u)] logger.debug(f"This worker is qualified for {len(ret_units)} unit.") logger.debug(f"Found {ret_units[:3]} for {worker}.") @@ -387,9 +376,7 @@ def get_assignments(self, status: Optional[str] = None) -> List["Assignment"]: Get assignments for this run, optionally filtering by their current status """ - assert ( - status is None or status in AssignmentState.valid() - ), "Invalid assignment status" + assert status is None or status in AssignmentState.valid(), "Invalid assignment status" assignments = self.db.find_assignments(task_run_id=self.db_id) if status is not None: assignments = [a for a in assignments if a.get_status() == status] @@ -402,9 +389,7 @@ def get_assignment_statuses(self) -> Dict[str, int]: assigns = self.get_assignments() assigns_with_status = [(x, x.get_status()) for x in assigns] return { - status: len( - [x for x, had_status in assigns_with_status if had_status == status] - ) + status: len([x for x, had_status in assigns_with_status if had_status == status]) for status in AssignmentState.valid() } @@ -480,9 +465,7 @@ def __repr__(self) -> str: return f"TaskRun({self.db_id})" @staticmethod - def new( - db: "MephistoDB", task: "Task", requester: Requester, param_string: str - ) -> "TaskRun": + def new(db: "MephistoDB", task: "Task", requester: Requester, param_string: str) -> "TaskRun": """ Create a new run for the given task with the given params """ diff --git a/mephisto/data_model/unit.py b/mephisto/data_model/unit.py index 4852f1022..abd1d7d94 100644 --- a/mephisto/data_model/unit.py +++ b/mephisto/data_model/unit.py @@ -181,9 +181,7 @@ def set_db_status(self, status: str) -> None: def _mark_agent_assignment(self) -> None: """Special helper to mark the transition from LAUNCHED to ASSIGNED""" - assert ( - self.db_status == AssignmentState.LAUNCHED - ), "can only mark LAUNCHED units" + assert self.db_status == AssignmentState.LAUNCHED, "can only mark LAUNCHED units" ACTIVE_UNIT_STATUSES.labels( status=AssignmentState.LAUNCHED, unit_type=INDEX_TO_TYPE_MAP[self.unit_index], @@ -394,9 +392,7 @@ def is_expired(self) -> bool: raise NotImplementedError() @staticmethod - def new( - db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float - ) -> "Unit": + def new(db: "MephistoDB", assignment: "Assignment", index: int, pay_amount: float) -> "Unit": """ Create a Unit for the given assignment diff --git a/mephisto/data_model/worker.py b/mephisto/data_model/worker.py index 9bc9b5fdf..939019f16 100644 --- a/mephisto/data_model/worker.py +++ b/mephisto/data_model/worker.py @@ -108,9 +108,7 @@ def get_agents(self, status: Optional[str] = None) -> List["Agent"]: return self.db.find_agents(worker_id=self.db_id, status=status) @staticmethod - def _register_worker( - db: "MephistoDB", worker_name: str, provider_type: str - ) -> "Worker": + def _register_worker(db: "MephistoDB", worker_name: str, provider_type: str) -> "Worker": """ Create an entry for this worker in the database """ @@ -120,9 +118,7 @@ def _register_worker( return worker @classmethod - def new_from_provider_data( - cls, db: "MephistoDB", creation_data: Dict[str, Any] - ) -> "Worker": + def new_from_provider_data(cls, db: "MephistoDB", creation_data: Dict[str, Any]) -> "Worker": """ Given the parameters passed through wrap_crowd_source.js, construct a new worker @@ -194,9 +190,7 @@ def revoke_qualification(self, qualification_name) -> bool: return False return True - def grant_qualification( - self, qualification_name: str, value: int = 1, skip_crowd=False - ): + def grant_qualification(self, qualification_name: str, value: int = 1, skip_crowd=False): """ Grant a positive or negative qualification to this worker @@ -205,13 +199,9 @@ def grant_qualification( """ found_qualifications = self.db.find_qualifications(qualification_name) if len(found_qualifications) == 0: - raise Exception( - f"No qualification by the name {qualification_name} found in the db" - ) + raise Exception(f"No qualification by the name {qualification_name} found in the db") - logger.debug( - f"Granting worker {self} qualification {qualification_name}: {value}" - ) + logger.debug(f"Granting worker {self} qualification {qualification_name}: {value}") qualification = found_qualifications[0] self.db.grant_qualification(qualification.db_id, self.db_id, value=value) if not skip_crowd: @@ -230,9 +220,7 @@ def __repr__(self) -> str: # Children classes can implement the following methods - def grant_crowd_qualification( - self, qualification_name: str, value: int = 1 - ) -> None: + def grant_crowd_qualification(self, qualification_name: str, value: int = 1) -> None: """ Grant a qualification by the given name to this worker diff --git a/mephisto/operations/client_io_handler.py b/mephisto/operations/client_io_handler.py index 15adbc393..3e135f453 100644 --- a/mephisto/operations/client_io_handler.py +++ b/mephisto/operations/client_io_handler.py @@ -119,27 +119,23 @@ def log_metrics_for_packet(self, packet: "Packet") -> None: if client_timestamp is None: client_timestamp = router_incoming_timestamp client_to_router = max(0, router_incoming_timestamp - client_timestamp) - router_processing = max( - 0, router_outgoing_timestamp - router_incoming_timestamp - ) + router_processing = max(0, router_outgoing_timestamp - router_incoming_timestamp) router_to_server = max(0, server_timestamp - router_outgoing_timestamp) server_processing = max(0, response_timestamp - server_timestamp) e2e_time = max(0, response_timestamp - client_timestamp) - E2E_PACKET_LATENCY.labels( - packet_type=packet.type, stage="client_to_router" - ).observe(client_to_router) - E2E_PACKET_LATENCY.labels( - packet_type=packet.type, stage="router_processing" - ).observe(router_processing) - E2E_PACKET_LATENCY.labels( - packet_type=packet.type, stage="router_to_server" - ).observe(router_to_server) - E2E_PACKET_LATENCY.labels( - packet_type=packet.type, stage="server_processing" - ).observe(server_processing) - E2E_PACKET_LATENCY.labels(packet_type=packet.type, stage="e2e_time").observe( - e2e_time + E2E_PACKET_LATENCY.labels(packet_type=packet.type, stage="client_to_router").observe( + client_to_router + ) + E2E_PACKET_LATENCY.labels(packet_type=packet.type, stage="router_processing").observe( + router_processing + ) + E2E_PACKET_LATENCY.labels(packet_type=packet.type, stage="router_to_server").observe( + router_to_server ) + E2E_PACKET_LATENCY.labels(packet_type=packet.type, stage="server_processing").observe( + server_processing + ) + E2E_PACKET_LATENCY.labels(packet_type=packet.type, stage="e2e_time").observe(e2e_time) def register_run(self, live_run: "LiveTaskRun") -> None: """Register a live run for this io handler""" @@ -165,9 +161,7 @@ def _on_catastrophic_disconnect(self, channel_id: str) -> None: live_run = self.get_live_run() live_run.force_shutdown = True - async def __on_channel_message_internal( - self, channel_id: str, packet: Packet - ) -> None: + async def __on_channel_message_internal(self, channel_id: str, packet: Packet) -> None: """Incoming message handler defers to the internal handler""" try: self._on_message(packet, channel_id) @@ -308,9 +302,7 @@ def _on_submit_onboarding(self, packet: Packet, channel_id: str) -> None: # On resubmit, ensure that the client has the same status agent = live_run.worker_pool.final_onboardings.get(onboarding_id) if agent is not None: - live_run.loop_wrap.execute_coro( - live_run.worker_pool.push_status_update(agent) - ) + live_run.loop_wrap.execute_coro(live_run.worker_pool.push_status_update(agent)) return agent = live_run.worker_pool.get_agent_for_id(onboarding_id) assert agent is not None, f"Could not find given agent by id {onboarding_id}" diff --git a/mephisto/operations/config_handler.py b/mephisto/operations/config_handler.py index 6d545f192..a7494b089 100644 --- a/mephisto/operations/config_handler.py +++ b/mephisto/operations/config_handler.py @@ -46,9 +46,7 @@ def init_config() -> None: with open(OLD_DATA_CONFIG_LOC, "r") as data_dir_file: loaded_data_dir = data_dir_file.read().strip() with open(DEFAULT_CONFIG_FILE, "w") as config_file: - config_file.write( - yaml.dump({CORE_SECTION: {DATA_STORAGE_KEY: loaded_data_dir}}) - ) + config_file.write(yaml.dump({CORE_SECTION: {DATA_STORAGE_KEY: loaded_data_dir}})) print(f"Removing DATA_LOC configuration file from {OLD_DATA_CONFIG_LOC}") os.unlink(OLD_DATA_CONFIG_LOC) elif not os.path.exists(DEFAULT_CONFIG_FILE): diff --git a/mephisto/operations/datatypes.py b/mephisto/operations/datatypes.py index 441032dc0..2717d46b1 100644 --- a/mephisto/operations/datatypes.py +++ b/mephisto/operations/datatypes.py @@ -80,9 +80,7 @@ def shutdown(self): class WorkerFailureReasons: NOT_QUALIFIED = "You are not currently qualified to work on this task..." - NO_AVAILABLE_UNITS = ( - "There is currently no available work, please try again later..." - ) + NO_AVAILABLE_UNITS = "There is currently no available work, please try again later..." TOO_MANY_CONCURRENT = "You are currently working on too many tasks concurrently to accept another, please finish your current work." MAX_FOR_TASK = "You have already completed the maximum amount of tasks the requester has set for this task." TASK_MISSING = "You appear to have already completed this task, or have disconnected long enough for your session to clear..." diff --git a/mephisto/operations/operator.py b/mephisto/operations/operator.py index 7fb3b8c2b..f02bacecf 100644 --- a/mephisto/operations/operator.py +++ b/mephisto/operations/operator.py @@ -115,9 +115,7 @@ def _get_requester_and_provider_from_config(self, run_config: DictConfig): if run_config.provider.requester_name == "MOCK_REQUESTER": requesters = [get_mock_requester(self.db)] else: - raise EntryDoesNotExistException( - f"No requester found with name {requester_name}" - ) + raise EntryDoesNotExistException(f"No requester found with name {requester_name}") requester = requesters[0] requester_id = requester.db_id provider_type = requester.provider_type @@ -146,13 +144,9 @@ def _create_live_task_run( # prepare the architect build_dir = os.path.join(task_run.get_run_dir(), "build") os.makedirs(build_dir, exist_ok=True) - architect = architect_class( - self.db, run_config, shared_state, task_run, build_dir - ) + architect = architect_class(self.db, run_config, shared_state, task_run, build_dir) # Create the backend runner - task_runner = blueprint_class.TaskRunnerClass( - task_run, run_config, shared_state - ) + task_runner = blueprint_class.TaskRunnerClass(task_run, run_config, shared_state) # Small hack for auto appending block qualification # TODO(OWN) we can use blueprint.mro() to discover BlueprintMixins and extract from there @@ -223,9 +217,7 @@ def launch_task_run_or_die( """ set_mephisto_log_level(level=run_config.get("log_level", "info")) - requester, provider_type = self._get_requester_and_provider_from_config( - run_config - ) + requester, provider_type = self._get_requester_and_provider_from_config(run_config) # Next get the abstraction classes, and run validation # before anything is actually created in the database @@ -299,9 +291,7 @@ def launch_task_run_or_die( live_run.client_io.launch_channels() except (KeyboardInterrupt, Exception) as e: - logger.error( - "Encountered error while launching run, shutting down", exc_info=True - ) + logger.error("Encountered error while launching run, shutting down", exc_info=True) try: live_run.architect.shutdown() except (KeyboardInterrupt, Exception) as architect_exception: @@ -338,9 +328,7 @@ async def _track_and_kill_runs(self): tracked_run.force_shutdown = True if not tracked_run.force_shutdown: task_run = tracked_run.task_run - task_run.update_completion_progress( - task_launcher=tracked_run.task_launcher - ) + task_run.update_completion_progress(task_launcher=tracked_run.task_launcher) if not task_run.get_is_completed(): continue if tracked_run.task_launcher.finished_generators is False: @@ -418,9 +406,7 @@ def shutdown(self, skip_input=True): try: tracked_run.task_launcher.shutdown() except (KeyboardInterrupt, SystemExit) as e: - logger.info( - f"Skipping waiting for launcher threads to join on task run {run_id}." - ) + logger.info(f"Skipping waiting for launcher threads to join on task run {run_id}.") def cant_cancel_expirations(sig, frame): logger.warn( @@ -507,13 +493,9 @@ def launch_task_run( Wrapper around validate_and_run_config_or_die that prints errors on failure, rather than throwing. Generally for use in scripts. """ - assert ( - not self.is_shutdown - ), "Cannot run a config on a shutdown operator. Create a new one." + assert not self.is_shutdown, "Cannot run a config on a shutdown operator. Create a new one." try: - return self.launch_task_run_or_die( - run_config=run_config, shared_state=shared_state - ) + return self.launch_task_run_or_die(run_config=run_config, shared_state=shared_state) except (KeyboardInterrupt, Exception) as e: logger.error("Ran into error while launching run: ", exc_info=True) return None @@ -574,9 +556,7 @@ def trigger_shutdown(): self._event_loop.call_later(timeout_time, trigger_shutdown) self._event_loop.run_forever() - def wait_for_runs_then_shutdown( - self, skip_input=False, log_rate: Optional[int] = None - ) -> None: + def wait_for_runs_then_shutdown(self, skip_input=False, log_rate: Optional[int] = None) -> None: """ Wait for task_runs to complete, and then shutdown. @@ -590,7 +570,7 @@ def wait_for_runs_then_shutdown( try: self._event_loop.run_forever() except Exception as e: - logger.exception('Encountered error during task run') + logger.exception("Encountered error during task run") except (KeyboardInterrupt, SystemExit) as e: logger.exception( "Cleaning up after keyboard interrupt, please " diff --git a/mephisto/operations/registry.py b/mephisto/operations/registry.py index 997d7cbf8..7d1d57e9b 100644 --- a/mephisto/operations/registry.py +++ b/mephisto/operations/registry.py @@ -51,9 +51,7 @@ def register_cls( f"Provided class {base_class} not a child of one of the mephisto " "abstractions, expected one of Blueprint, Architect, or CrowdProvider." ) - register_abstraction_config( - name=name, node=base_class.ArgsClass, abstraction_type=type_key - ) + register_abstraction_config(name=name, node=base_class.ArgsClass, abstraction_type=type_key) return base_class return register_cls @@ -93,22 +91,16 @@ def fill_registries(): ) # Import Mephisto Architects - architect_root = os.path.join( - get_root_dir(), "mephisto", "abstractions", "architects" - ) + architect_root = os.path.join(get_root_dir(), "mephisto", "abstractions", "architects") for filename in os.listdir(architect_root): if filename.endswith("architect.py"): architect_name = filename[: filename.find(".py")] - importlib.import_module( - f"mephisto.abstractions.architects.{architect_name}" - ) + importlib.import_module(f"mephisto.abstractions.architects.{architect_name}") # After imports are recursive, manage this more cleanly importlib.import_module("mephisto.abstractions.architects.ec2.ec2_architect") # Import Mephisto Blueprints - blueprint_root = os.path.join( - get_root_dir(), "mephisto", "abstractions", "blueprints" - ) + blueprint_root = os.path.join(get_root_dir(), "mephisto", "abstractions", "blueprints") for dir_name in os.listdir(blueprint_root): blueprint_dir = os.path.join(blueprint_root, dir_name) if not os.path.isdir(blueprint_dir): @@ -126,9 +118,7 @@ def get_crowd_provider_from_type(provider_type: str) -> Type["CrowdProvider"]: if provider_type in PROVIDERS: return PROVIDERS[provider_type] else: - raise NotImplementedError( - f"Missing provider type {provider_type}, is it registered?" - ) + raise NotImplementedError(f"Missing provider type {provider_type}, is it registered?") def get_blueprint_from_type(task_type: str) -> Type["Blueprint"]: @@ -136,9 +126,7 @@ def get_blueprint_from_type(task_type: str) -> Type["Blueprint"]: if task_type in BLUEPRINTS: return BLUEPRINTS[task_type] else: - raise NotImplementedError( - f"Missing blueprint type {task_type}, is it registered?" - ) + raise NotImplementedError(f"Missing blueprint type {task_type}, is it registered?") def get_architect_from_type(architect_type: str) -> Type["Architect"]: @@ -146,9 +134,7 @@ def get_architect_from_type(architect_type: str) -> Type["Architect"]: if architect_type in ARCHITECTS: return ARCHITECTS[architect_type] else: - raise NotImplementedError( - f"Missing architect type {architect_type}, is it registered?" - ) + raise NotImplementedError(f"Missing architect type {architect_type}, is it registered?") def get_valid_provider_types() -> List[str]: diff --git a/mephisto/operations/task_launcher.py b/mephisto/operations/task_launcher.py index 178ca39ff..0a9df576d 100644 --- a/mephisto/operations/task_launcher.py +++ b/mephisto/operations/task_launcher.py @@ -104,9 +104,7 @@ def _create_single_assignment(self, assignment_data) -> None: self.assignments.append(assignment) unit_count = len(assignment_data.unit_data) for unit_idx in range(unit_count): - unit = self.UnitClass.new( - self.db, assignment, unit_idx, task_args.task_reward - ) + unit = self.UnitClass.new(self.db, assignment, unit_idx, task_args.task_reward) self.units.append(unit) with self.unlaunched_units_access_condition: self.unlaunched_units[unit.db_id] = unit @@ -147,10 +145,7 @@ def generate_units(self): units_id_to_remove = [] for db_id, unit in self.launched_units.items(): status = unit.get_status() - if ( - status != AssignmentState.LAUNCHED - and status != AssignmentState.ASSIGNED - ): + if status != AssignmentState.LAUNCHED and status != AssignmentState.ASSIGNED: units_id_to_remove.append(db_id) for db_id in units_id_to_remove: self.launched_units.pop(db_id) @@ -204,9 +199,7 @@ def launch_units(self, url: str) -> None: ) self.units_thread.start() - def launch_evaluation_unit( - self, unit_data: Dict[str, Any], unit_type_index: int - ) -> "Unit": + def launch_evaluation_unit(self, unit_data: Dict[str, Any], unit_type_index: int) -> "Unit": """Launch a specific evaluation unit, used for quality control""" assert ( self.launch_url is not None diff --git a/mephisto/operations/worker_pool.py b/mephisto/operations/worker_pool.py index 6022d3797..359ce975f 100644 --- a/mephisto/operations/worker_pool.py +++ b/mephisto/operations/worker_pool.py @@ -125,24 +125,18 @@ def get_live_run(self) -> "LiveTaskRun": assert live_run is not None, "Live run must be registered to use this" return live_run - def get_agent_for_id( - self, agent_id: str - ) -> Optional[Union["Agent", "OnboardingAgent"]]: + def get_agent_for_id(self, agent_id: str) -> Optional[Union["Agent", "OnboardingAgent"]]: """Temporary method to get an agent, while API is figured out""" if agent_id in self.agents: return self.agents[agent_id] elif agent_id in self.onboarding_agents: return self.onboarding_agents[agent_id] elif agent_id in self.final_onboardings: - logger.debug( - f"Found agent id {agent_id} in final_onboardings for get_agent_for_id" - ) + logger.debug(f"Found agent id {agent_id} in final_onboardings for get_agent_for_id") return self.final_onboardings[agent_id] return None - async def register_worker( - self, crowd_data: Dict[str, Any], request_id: str - ) -> None: + async def register_worker(self, crowd_data: Dict[str, Any], request_id: str) -> None: """ First process the worker registration, then hand off for registering an agent @@ -177,9 +171,7 @@ async def register_worker( AGENT_DETAILS_COUNT.labels(response="not_qualified").inc() live_run.client_io.enqueue_agent_details( request_id, - AgentDetails( - failure_reason=WorkerFailureReasons.NOT_QUALIFIED - ).to_dict(), + AgentDetails(failure_reason=WorkerFailureReasons.NOT_QUALIFIED).to_dict(), ) else: await self.register_agent(crowd_data, worker, request_id) @@ -197,9 +189,7 @@ async def _assign_unit_to_agent( task_runner = live_run.task_runner crowd_provider = live_run.provider - logger.debug( - f"Worker {worker.db_id} is being assigned one of {len(units)} units." - ) + logger.debug(f"Worker {worker.db_id} is being assigned one of {len(units)} units.") reserved_unit = None while len(units) > 0 and reserved_unit is None: @@ -234,9 +224,7 @@ async def _assign_unit_to_agent( logger.debug(f"Created agent {agent}, {agent.db_id}.") # TODO(#649) this is IO bound - with EXTERNAL_FUNCTION_LATENCY.labels( - function="get_init_data_for_agent" - ).time(): + with EXTERNAL_FUNCTION_LATENCY.labels(function="get_init_data_for_agent").time(): init_task_data = await loop.run_in_executor( None, partial( @@ -285,9 +273,7 @@ async def _assign_unit_to_agent( non_null_agents = [a for a in agents if a is not None] # Launch the backend for this assignment registered_agents = [ - self.agents[a.get_agent_id()] - for a in non_null_agents - if a is not None + self.agents[a.get_agent_id()] for a in non_null_agents if a is not None ] live_run.task_runner.execute_assignment(assignment, registered_agents) @@ -323,33 +309,23 @@ async def register_agent_from_onboarding(self, onboarding_agent: "OnboardingAgen ) assert blueprint.onboarding_qualification_name is not None - worker.grant_qualification( - blueprint.onboarding_qualification_name, int(worker_passed) - ) + worker.grant_qualification(blueprint.onboarding_qualification_name, int(worker_passed)) if not worker_passed: ONBOARDING_OUTCOMES.labels(outcome="failed").inc() - worker.grant_qualification( - blueprint.onboarding_failed_name, int(worker_passed) - ) + worker.grant_qualification(blueprint.onboarding_failed_name, int(worker_passed)) onboarding_agent.update_status(AgentState.STATUS_REJECTED) logger.info(f"Onboarding agent {onboarding_id} failed onboarding") else: ONBOARDING_OUTCOMES.labels(outcome="passed").inc() onboarding_agent.update_status(AgentState.STATUS_APPROVED) - logger.info( - f"Onboarding agent {onboarding_id} registered out from onboarding" - ) + logger.info(f"Onboarding agent {onboarding_id} registered out from onboarding") # get the list of tentatively valid units - with EXTERNAL_FUNCTION_LATENCY.labels( - function="get_valid_units_for_worker" - ).time(): + with EXTERNAL_FUNCTION_LATENCY.labels(function="get_valid_units_for_worker").time(): units = await loop.run_in_executor( None, partial(live_run.task_run.get_valid_units_for_worker, worker) ) - with EXTERNAL_FUNCTION_LATENCY.labels( - function="filter_units_for_worker" - ).time(): + with EXTERNAL_FUNCTION_LATENCY.labels(function="filter_units_for_worker").time(): usable_units = await loop.run_in_executor( None, partial(live_run.task_runner.filter_units_for_worker, units, worker), @@ -396,9 +372,7 @@ async def reconnect_agent(self, agent_id: str, request_id: str): task_runner = live_run.task_runner agent = self.get_agent_for_id(agent_id) if agent is None: - logger.info( - f"Looking for reconnecting agent {agent_id} but none found locally" - ) + logger.info(f"Looking for reconnecting agent {agent_id} but none found locally") AGENT_DETAILS_COUNT.labels(response="agent_missing").inc() live_run.client_io.enqueue_agent_details( request_id, @@ -414,9 +388,7 @@ async def reconnect_agent(self, agent_id: str, request_id: str): # Rejected agent should get failed response live_run.client_io.enqueue_agent_details( request_id, - AgentDetails( - failure_reason=WorkerFailureReasons.NOT_QUALIFIED - ).to_dict(), + AgentDetails(failure_reason=WorkerFailureReasons.NOT_QUALIFIED).to_dict(), ) elif agent.get_status() == AgentState.STATUS_DISCONNECT: # Disconnected agent should get missing response @@ -428,10 +400,7 @@ async def reconnect_agent(self, agent_id: str, request_id: str): ) else: blueprint = live_run.blueprint - assert ( - isinstance(blueprint, OnboardingRequired) - and blueprint.use_onboarding - ) + assert isinstance(blueprint, OnboardingRequired) and blueprint.use_onboarding onboard_data = blueprint.get_onboarding_data(worker.db_id) live_run.client_io.enqueue_agent_details( request_id, @@ -443,9 +412,7 @@ async def reconnect_agent(self, agent_id: str, request_id: str): ) else: # TODO(#649) this is IO bound - with EXTERNAL_FUNCTION_LATENCY.labels( - function="get_init_data_for_agent" - ).time(): + with EXTERNAL_FUNCTION_LATENCY.labels(function="get_init_data_for_agent").time(): init_task_data = await loop.run_in_executor( None, partial( @@ -476,13 +443,8 @@ async def _assign_unit_or_qa( # Check screening if isinstance(blueprint, ScreenTaskRequired) and blueprint.use_screening_task: - if ( - blueprint.worker_needs_screening(worker) - and blueprint.should_generate_unit() - ): - with EXTERNAL_FUNCTION_LATENCY.labels( - function="get_screening_unit_data" - ).time(): + if blueprint.worker_needs_screening(worker) and blueprint.should_generate_unit(): + with EXTERNAL_FUNCTION_LATENCY.labels(function="get_screening_unit_data").time(): screening_data = await loop.run_in_executor( None, blueprint.get_screening_unit_data ) @@ -491,9 +453,7 @@ async def _assign_unit_or_qa( assert ( launcher is not None ), "LiveTaskRun must have launcher to use screening tasks" - with EXTERNAL_FUNCTION_LATENCY.labels( - function="launch_screening_unit" - ).time(): + with EXTERNAL_FUNCTION_LATENCY.labels(function="launch_screening_unit").time(): screen_unit = await loop.run_in_executor( None, partial( @@ -511,9 +471,7 @@ async def _assign_unit_or_qa( failure_reason=WorkerFailureReasons.NO_AVAILABLE_UNITS, ).to_dict(), ) - logger.debug( - f"No screening units left for {agent_registration_id}." - ) + logger.debug(f"No screening units left for {agent_registration_id}.") return # Check golds if isinstance(blueprint, UseGoldUnit) and blueprint.use_golds: @@ -549,9 +507,7 @@ async def _assign_unit_or_qa( # Register the correct unit type await self._assign_unit_to_agent(crowd_data, worker, request_id, units) - async def register_agent( - self, crowd_data: Dict[str, Any], worker: "Worker", request_id: str - ): + async def register_agent(self, crowd_data: Dict[str, Any], worker: "Worker", request_id: str): """Process an agent registration packet to register an agent, returning the agent_id""" # Process a new agent logger.debug(f"Registering agent {crowd_data}, {request_id}") @@ -561,9 +517,7 @@ async def register_agent( agent_registration_id = crowd_data["agent_registration_id"] # get the list of tentatively valid units - with EXTERNAL_FUNCTION_LATENCY.labels( - function="get_valid_units_for_worker" - ).time(): + with EXTERNAL_FUNCTION_LATENCY.labels(function="get_valid_units_for_worker").time(): units = task_run.get_valid_units_for_worker(worker) if len(units) == 0: @@ -575,13 +529,9 @@ async def register_agent( failure_reason=WorkerFailureReasons.NO_AVAILABLE_UNITS, ).to_dict(), ) - logger.debug( - f"agent_registration_id {agent_registration_id}, had no valid units." - ) + logger.debug(f"agent_registration_id {agent_registration_id}, had no valid units.") return - with EXTERNAL_FUNCTION_LATENCY.labels( - function="filter_units_for_worker" - ).time(): + with EXTERNAL_FUNCTION_LATENCY.labels(function="filter_units_for_worker").time(): units = await loop.run_in_executor( None, partial(live_run.task_runner.filter_units_for_worker, units, worker), @@ -590,9 +540,7 @@ async def register_agent( blueprint = live_run.blueprint if isinstance(blueprint, OnboardingRequired) and blueprint.use_onboarding: qual_name = blueprint.onboarding_qualification_name - assert ( - qual_name is not None - ), "Cannot be using onboarding and have a null qual" + assert qual_name is not None, "Cannot be using onboarding and have a null qual" if worker.is_disqualified(qual_name): AGENT_DETAILS_COUNT.labels(response="not_qualified").inc() live_run.client_io.enqueue_agent_details( @@ -639,8 +587,7 @@ async def register_agent( ).to_dict(), ) logger.info( - f"{worker} is starting onboarding thread with " - f"onboarding {onboard_agent}." + f"{worker} is starting onboarding thread with " f"onboarding {onboard_agent}." ) async def cleanup_onboarding(): @@ -651,16 +598,12 @@ async def cleanup_onboarding(): ACTIVE_ONBOARDINGS.dec() # Run the onboarding - live_run.task_runner.execute_onboarding( - onboard_agent, cleanup_onboarding - ) + live_run.task_runner.execute_onboarding(onboard_agent, cleanup_onboarding) return await self._assign_unit_or_qa(crowd_data, worker, request_id, units) - async def push_status_update( - self, agent: Union["Agent", "OnboardingAgent"] - ) -> None: + async def push_status_update(self, agent: Union["Agent", "OnboardingAgent"]) -> None: """ Force a status update for a specific agent, pushing the db status to the frontend client diff --git a/mephisto/scripts/local_db/gh_actions/auto_generate_blueprint.py b/mephisto/scripts/local_db/gh_actions/auto_generate_blueprint.py index da21db372..69b0fd8ca 100644 --- a/mephisto/scripts/local_db/gh_actions/auto_generate_blueprint.py +++ b/mephisto/scripts/local_db/gh_actions/auto_generate_blueprint.py @@ -21,10 +21,7 @@ def create_blueprint_info(blueprint_file, arg_dict): if isinstance(item_content, str) else item_content ) - if ( - isinstance(item_to_append, str) - and item_to_append.rfind("mephisto/") != -1 - ): + if isinstance(item_to_append, str) and item_to_append.rfind("mephisto/") != -1: item_to_append = item_to_append[ item_to_append.rfind("mephisto/") : len(item_to_append) ] diff --git a/mephisto/scripts/local_db/gh_actions/auto_generate_provider.py b/mephisto/scripts/local_db/gh_actions/auto_generate_provider.py index e20fb01db..7205038fe 100644 --- a/mephisto/scripts/local_db/gh_actions/auto_generate_provider.py +++ b/mephisto/scripts/local_db/gh_actions/auto_generate_provider.py @@ -19,9 +19,7 @@ def main(): valid_provider_types = get_valid_provider_types() for provider_type in valid_provider_types: provider_file.new_header(level=2, title=provider_type.replace("_", " ")) - args = get_wut_arguments( - ("provider={provider_name}".format(provider_name=provider_type),) - ) + args = get_wut_arguments(("provider={provider_name}".format(provider_name=provider_type),)) arg_dict = args[0] create_blueprint_info(provider_file, arg_dict) diff --git a/mephisto/scripts/local_db/load_data_to_mephisto_db.py b/mephisto/scripts/local_db/load_data_to_mephisto_db.py index 732e2e673..ec8e2df6b 100644 --- a/mephisto/scripts/local_db/load_data_to_mephisto_db.py +++ b/mephisto/scripts/local_db/load_data_to_mephisto_db.py @@ -81,9 +81,7 @@ def main(): # Get or create a task run for this tasks = db.find_tasks() - task_names = [ - t.task_name for t in tasks if t.task_type == BLUEPRINT_TYPE_STATIC_REACT - ] + task_names = [t.task_name for t in tasks if t.task_type == BLUEPRINT_TYPE_STATIC_REACT] print(f"Use an existing run? ") print(f"You have the following existing mock runs:") diff --git a/mephisto/scripts/local_db/remove_accepted_tip.py b/mephisto/scripts/local_db/remove_accepted_tip.py index f516c69ea..bb9a000d6 100644 --- a/mephisto/scripts/local_db/remove_accepted_tip.py +++ b/mephisto/scripts/local_db/remove_accepted_tip.py @@ -41,9 +41,7 @@ def remove_tip_from_tips_file( tips_location = blueprint_task_run_args["tips_location"] does_file_exist = exists(tips_location) if does_file_exist == False: - print( - "\n[red]You do not have a tips.csv file in your task's output directory[/red]" - ) + print("\n[red]You do not have a tips.csv file in your task's output directory[/red]") quit() lines_to_write = [] @@ -110,12 +108,8 @@ def main(): print("") if removal_response == TipsRemovalType.REMOVE.value: - remove_tip_from_tips_file( - accepted_tips_copy, i, unit.get_task_run() - ) - remove_tip_from_metadata( - accepted_tips, accepted_tips_copy, i, unit - ) + remove_tip_from_tips_file(accepted_tips_copy, i, unit.get_task_run()) + remove_tip_from_metadata(accepted_tips, accepted_tips_copy, i, unit) print("Removed tip\n") elif removal_response == TipsRemovalType.KEEP.value: print("Did not remove tip\n") diff --git a/mephisto/scripts/local_db/review_feedback_for_task.py b/mephisto/scripts/local_db/review_feedback_for_task.py index ca7b654d6..501f666be 100644 --- a/mephisto/scripts/local_db/review_feedback_for_task.py +++ b/mephisto/scripts/local_db/review_feedback_for_task.py @@ -44,9 +44,7 @@ def set_feedback_as_reviewed(feedback: List, id: str, unit: Unit) -> None: index_to_modify = get_index_of_value(feedback_ids, id) if assigned_agent is not None: feedback[index_to_modify]["reviewed"] = True - assigned_agent.state.update_metadata( - property_name="feedback", property_value=feedback - ) + assigned_agent.state.update_metadata(property_name="feedback", property_value=feedback) def print_out_reviewed_feedback_elements( diff --git a/mephisto/scripts/local_db/review_tips_for_task.py b/mephisto/scripts/local_db/review_tips_for_task.py index 61bf826e5..c8e528b39 100644 --- a/mephisto/scripts/local_db/review_tips_for_task.py +++ b/mephisto/scripts/local_db/review_tips_for_task.py @@ -56,9 +56,7 @@ def add_row_to_tips_file(task_run: TaskRun, item_to_add: Dict[str, Any]): try: create_tips_file.touch(exist_ok=True) except FileNotFoundError: - print( - "\n[red]Your task folder must have an assets folder in it.[/red]\n" - ) + print("\n[red]Your task folder must have an assets folder in it.[/red]\n") quit() with open(tips_location, "r") as inp, open(tips_location, "a+") as tips_file: @@ -81,9 +79,7 @@ def remove_tip_from_metadata( if assigned_agent is not None: tips_copy.pop(index_to_remove) - assigned_agent.state.update_metadata( - property_name="tips", property_value=tips_copy - ) + assigned_agent.state.update_metadata(property_name="tips", property_value=tips_copy) else: print("[red]An assigned agent was not able to be found for this tip[/red]") quit() @@ -99,9 +95,7 @@ def accept_tip(tips: List, tips_copy: List, i: int, unit: Unit) -> None: if assigned_agent is not None: tips_copy[index_to_update]["accepted"] = True add_row_to_tips_file(unit.get_task_run(), tips_copy[index_to_update]) - assigned_agent.state.update_metadata( - property_name="tips", property_value=tips_copy - ) + assigned_agent.state.update_metadata(property_name="tips", property_value=tips_copy) def main(): @@ -176,9 +170,7 @@ def main(): bonus, reason, unit ) if bonus_successfully_paid: - print( - "\n[green]Bonus Successfully Paid![/green]\n" - ) + print("\n[green]Bonus Successfully Paid![/green]\n") else: print( "\n[red]There was an error when paying out your bonus[/red]\n" diff --git a/mephisto/scripts/mturk/cleanup.py b/mephisto/scripts/mturk/cleanup.py index 2826db066..c989af2b6 100644 --- a/mephisto/scripts/mturk/cleanup.py +++ b/mephisto/scripts/mturk/cleanup.py @@ -43,9 +43,7 @@ def main(): outstanding_hit_types = get_outstanding_hits(client) num_hit_types = len(outstanding_hit_types.keys()) - sum_hits = sum( - [len(outstanding_hit_types[x]) for x in outstanding_hit_types.keys()] - ) + sum_hits = sum([len(outstanding_hit_types[x]) for x in outstanding_hit_types.keys()]) all_hits: List[Dict[str, Any]] = [] for hit_type in outstanding_hit_types.keys(): @@ -80,17 +78,14 @@ def main(): print(f"LAUNCH TIME: {creation_time_str}") print(f"HIT COUNT: {len(outstanding_hit_types[hit_type])}") should_clear = input( - "Should we cleanup this hit type? (y)es for yes, anything else for no: " - "\n>> " + "Should we cleanup this hit type? (y)es for yes, anything else for no: " "\n>> " ) if should_clear.lower().startswith("y"): use_hits += outstanding_hit_types[hit_type] elif run_type.lower().startswith("a"): use_hits = all_hits elif run_type.lower().startswith("o"): - old_cutoff = datetime.now(all_hits[0]["CreationTime"].tzinfo) - timedelta( - days=14 - ) + old_cutoff = datetime.now(all_hits[0]["CreationTime"].tzinfo) - timedelta(days=14) use_hits = [h for h in all_hits if h["CreationTime"] < old_cutoff] else: run_type = input("Options are (t)itle, (o)ld, or (a)ll:\n>> ") diff --git a/mephisto/scripts/mturk/identify_broken_units.py b/mephisto/scripts/mturk/identify_broken_units.py index 7a4f48229..b5b532d30 100644 --- a/mephisto/scripts/mturk/identify_broken_units.py +++ b/mephisto/scripts/mturk/identify_broken_units.py @@ -36,9 +36,7 @@ def main(): and u.get_assigned_agent() is not None ] completed_timeout_units = [ - u - for u in completed_agented_units - if u.get_assigned_agent().get_status() == "timeout" + u for u in completed_agented_units if u.get_assigned_agent().get_status() == "timeout" ] if len(completed_agentless_units) == 0 and len(completed_timeout_units) == 0: @@ -52,9 +50,7 @@ def main(): ) print(completed_timeout_units[-5:]) - agents = db.find_agents(task_run_id=TASK_RUN) + db.find_agents( - task_run_id=TASK_RUN - 1 - ) + agents = db.find_agents(task_run_id=TASK_RUN) + db.find_agents(task_run_id=TASK_RUN - 1) requester = units[0].get_requester() client = requester._get_client(requester._requester_name) @@ -72,9 +68,7 @@ def main(): print(f"Querying assignments for the {len(hits)} tasks.") - task_assignments_uf = [ - get_assignments_for_hit(client, h["HITId"]) for h in task_hits - ] + task_assignments_uf = [get_assignments_for_hit(client, h["HITId"]) for h in task_hits] task_assignments = [t[0] for t in task_assignments_uf if len(t) != 0] print(f"Found {len(task_assignments)} assignments to map.") @@ -88,13 +82,9 @@ def main(): worker_id_to_agents[worker_id].append(a) print("Constructing hit-id to unit mapping for completed...") - hit_ids_to_unit = { - u.get_mturk_hit_id(): u for u in units if u.get_mturk_hit_id() is not None - } + hit_ids_to_unit = {u.get_mturk_hit_id(): u for u in units if u.get_mturk_hit_id() is not None} - unattributed_assignments = [ - t for t in task_assignments if t["HITId"] not in hit_ids_to_unit - ] + unattributed_assignments = [t for t in task_assignments if t["HITId"] not in hit_ids_to_unit] print(f"Found {len(unattributed_assignments)} assignments with no mapping!") @@ -113,9 +103,7 @@ def main(): if units_agent is None or units_agent.db_id != agent.db_id: continue - print( - f"Agent {agent} would be a good candidate to reconcile {assignment['HITId']}" - ) + print(f"Agent {agent} would be a good candidate to reconcile {assignment['HITId']}") # TODO(WISH) automate the below print( "You can do this manually by selecting the best candidate, then " diff --git a/mephisto/scripts/mturk/launch_makeup_hits.py b/mephisto/scripts/mturk/launch_makeup_hits.py index 29c1539ca..e45d8770b 100644 --- a/mephisto/scripts/mturk/launch_makeup_hits.py +++ b/mephisto/scripts/mturk/launch_makeup_hits.py @@ -134,14 +134,12 @@ def main(): qualification = make_qualification_dict(qual_name, QUAL_EXISTS, None) qual_map = requester.datastore.get_qualification_mapping(qual_name) if qual_map is None: - qualification[ - "QualificationTypeId" - ] = requester._create_new_mturk_qualification(qual_name) + qualification["QualificationTypeId"] = requester._create_new_mturk_qualification( + qual_name + ) else: qualification["QualificationTypeId"] = qual_map["mturk_qualification_id"] - give_worker_qualification( - client, worker_id, qualification["QualificationTypeId"] - ) + give_worker_qualification(client, worker_id, qualification["QualificationTypeId"]) # Create the task run for this HIT print(f"Creating task run and data model components for this HIT") diff --git a/mephisto/scripts/mturk/print_outstanding_hit_status.py b/mephisto/scripts/mturk/print_outstanding_hit_status.py index e71add3e2..eec5d5982 100644 --- a/mephisto/scripts/mturk/print_outstanding_hit_status.py +++ b/mephisto/scripts/mturk/print_outstanding_hit_status.py @@ -23,9 +23,7 @@ def main(): task_run = TaskRun.get(db, task_run_id) requester = task_run.get_requester() if not isinstance(requester, MTurkRequester): - print( - "Must be checking a task launched on MTurk, this one uses the following requester:" - ) + print("Must be checking a task launched on MTurk, this one uses the following requester:") print(requester) exit(0) diff --git a/mephisto/scripts/mturk/soft_block_workers_by_mturk_id.py b/mephisto/scripts/mturk/soft_block_workers_by_mturk_id.py index 2036538f0..ee9c6b46f 100644 --- a/mephisto/scripts/mturk/soft_block_workers_by_mturk_id.py +++ b/mephisto/scripts/mturk/soft_block_workers_by_mturk_id.py @@ -27,9 +27,7 @@ def main(): break workers_to_block.append(new_id) - direct_soft_block_mturk_workers( - db, workers_to_block, soft_block_qual_name, requester_name - ) + direct_soft_block_mturk_workers(db, workers_to_block, soft_block_qual_name, requester_name) if __name__ == "__main__": diff --git a/mephisto/tools/data_browser.py b/mephisto/tools/data_browser.py index 47af63080..0f7dde642 100644 --- a/mephisto/tools/data_browser.py +++ b/mephisto/tools/data_browser.py @@ -46,17 +46,13 @@ def _get_units_for_task_runs(self, task_runs: List[TaskRun]) -> List[Unit]: Return a list of all Units in a terminal completed state from all the provided TaskRuns. """ - return self.collect_matching_units_from_task_runs( - task_runs, AssignmentState.completed() - ) + return self.collect_matching_units_from_task_runs(task_runs, AssignmentState.completed()) def _get_all_units_for_task_runs(self, task_runs: List[TaskRun]) -> List[Unit]: """ Does the same as _get_units_for_task_runs except that it includes the EXPIRED state """ - return self.collect_matching_units_from_task_runs( - task_runs, AssignmentState.final_agent() - ) + return self.collect_matching_units_from_task_runs(task_runs, AssignmentState.final_agent()) def get_task_name_list(self) -> List[str]: return [task.task_name for task in self.db.find_tasks()] @@ -94,9 +90,7 @@ def get_data_from_unit(self, unit: Unit) -> Dict[str, Any]: relevant assignment this unit was a part of. """ agent = unit.get_assigned_agent() - assert ( - agent is not None - ), f"Trying to get completed data from unassigned unit {unit}" + assert agent is not None, f"Trying to get completed data from unassigned unit {unit}" return { "worker_id": agent.worker_id, "unit_id": unit.db_id, @@ -121,9 +115,7 @@ def get_workers_with_qualification(self, qualification_name: str) -> List[Worker ) return [Worker.get(self.db, qual.worker_id) for qual in qualifieds] - def get_metadata_property_from_task_name( - self, task_name: str, property_name: str - ) -> List[Any]: + def get_metadata_property_from_task_name(self, task_name: str, property_name: str) -> List[Any]: """Returns all metadata for a task by going through its agents""" units = self.get_all_units_for_task_name(task_name=task_name) diff --git a/mephisto/tools/examine_utils.py b/mephisto/tools/examine_utils.py index d0fbc82c5..1f2a8a6d2 100644 --- a/mephisto/tools/examine_utils.py +++ b/mephisto/tools/examine_utils.py @@ -225,18 +225,14 @@ def run_examine_by_worker( if apply_all_decision is not None: decision = apply_all_decision else: - decision = input( - "Do you want to accept this work? (a)ccept, (r)eject, (p)ass: " - ) + decision = input("Do you want to accept this work? (a)ccept, (r)eject, (p)ass: ") while decision.lower() not in options: decision = input( "Decision must be one of a, p, r. Use CAPS to apply to all remaining for worker: " ) agent = unit.get_assigned_agent() - assert ( - agent is not None - ), f"Can't make decision on None agent... issue with {unit}" + assert agent is not None, f"Can't make decision on None agent... issue with {unit}" if decision.lower() == "a": agent.approve_work() if decision == "A" and approve_qualification is not None: @@ -248,9 +244,7 @@ def run_examine_by_worker( elif decision.lower() == "p": agent.soft_reject_work() if apply_all_decision is None and block_qualification is not None: - should_soft_block = input( - "Do you want to soft block this worker? (y)es/(n)o: " - ) + should_soft_block = input("Do you want to soft block this worker? (y)es/(n)o: ") if should_soft_block.lower() in ["y", "yes"]: worker.grant_qualification(block_qualification, 1) elif decision.lower() == "v": @@ -265,9 +259,7 @@ def run_examine_by_worker( else: # decision = 'r' if apply_all_decision is None: reason = input("Why are you rejecting this work? ") - should_block = input( - "Do you want to hard block this worker? (y)es/(n)o: " - ) + should_block = input("Do you want to hard block this worker? (y)es/(n)o: ") if should_block.lower() in ["y", "yes"]: block_reason = input("Why permanently block this worker? ") worker.block_worker(block_reason, unit=unit) diff --git a/mephisto/tools/scripts.py b/mephisto/tools/scripts.py index 334fd9c96..9703075fe 100644 --- a/mephisto/tools/scripts.py +++ b/mephisto/tools/scripts.py @@ -95,9 +95,7 @@ def task_script( if config is not None: used_config = config else: - assert ( - default_config_file is not None - ), "Must provide one of config or default_config_file" + assert default_config_file is not None, "Must provide one of config or default_config_file" used_config = build_default_task_config(default_config_file) register_script_config(name="taskconfig", module=used_config) @@ -114,9 +112,7 @@ def process_config_and_run_main(cfg: "DictConfig"): operator.shutdown() return ret_val - absolute_config_path = os.path.abspath( - os.path.join(get_run_file_dir(), config_path) - ) + absolute_config_path = os.path.abspath(os.path.join(get_run_file_dir(), config_path)) hydra_wrapper = hydra.main( config_path=absolute_config_path, config_name="taskconfig", @@ -179,9 +175,7 @@ def augment_config_from_db(script_cfg: DictConfig, db: "MephistoDB") -> DictConf elif len(reqs) == 1: req = reqs[0] requester_name = req.requester_name - print( - f"Found one `{provider_type}` requester to launch with: {requester_name}" - ) + print(f"Found one `{provider_type}` requester to launch with: {requester_name}") else: req = reqs[-1] requester_name = req.requester_name @@ -210,9 +204,7 @@ def augment_config_from_db(script_cfg: DictConfig, db: "MephistoDB") -> DictConf if provider_type in ["mturk"]: try_prerun_cleanup(db, cfg.provider.requester_name) - input( - f"This task is going to launch live on {provider_type}, press enter to continue: " - ) + input(f"This task is going to launch live on {provider_type}, press enter to continue: ") if provider_type in ["mturk_sandbox", "mturk"] and architect_type not in [ "heroku", "ec2", @@ -282,8 +274,7 @@ def build_custom_bundle( packages_installed = subprocess.call(["npm", "install"]) if packages_installed != 0: raise Exception( - "please make sure npm is installed, otherwise view " - "the above error for more info." + "please make sure npm is installed, otherwise view " "the above error for more info." ) if post_install_script is not None and len(post_install_script) > 0: diff --git a/mephisto/utils/metrics.py b/mephisto/utils/metrics.py index e038b77a1..a3f861035 100644 --- a/mephisto/utils/metrics.py +++ b/mephisto/utils/metrics.py @@ -145,16 +145,12 @@ def launch_prometheus_server(args: Optional["DictConfig"] = None) -> bool: except requests.exceptions.ConnectionError: is_ok = False if not is_ok: - logger.warning( - "Prometheus PID existed, but server doesn't appear to be up." - ) + logger.warning("Prometheus PID existed, but server doesn't appear to be up.") if _server_process_running(_get_pid_from_file(PROMETHEUS_PID_FILE)): logger.warning( "Prometheus server appears to be running though! exiting as unsure what to do..." ) - raise InaccessiblePrometheusServer( - "Prometheus server running but inaccessible" - ) + raise InaccessiblePrometheusServer("Prometheus server running but inaccessible") else: logger.warning( "Clearing prometheus pid as the server isn't running. " @@ -245,9 +241,7 @@ def get_dash_url(args: Optional["DictConfig"] = None): return f"localhost:3032{output[0]['url']}" -def shutdown_prometheus_server( - args: Optional["DictConfig"] = None, expect_exists=False -): +def shutdown_prometheus_server(args: Optional["DictConfig"] = None, expect_exists=False): """ Shutdown the prometheus server """ diff --git a/mephisto/utils/qualifications.py b/mephisto/utils/qualifications.py index c1bb18a4e..dfe2bf8d3 100644 --- a/mephisto/utils/qualifications.py +++ b/mephisto/utils/qualifications.py @@ -72,15 +72,11 @@ def as_valid_qualification_dict(qual_dict: Dict[str, Any]) -> Dict[str, Any]: ] for key in required_keys: if key not in qual_dict: - raise AssertionError( - f"Required key {key} not in qualification dict {qual_dict}" - ) + raise AssertionError(f"Required key {key} not in qualification dict {qual_dict}") qual_name = qual_dict["qualification_name"] if type(qual_name) is not str or len(qual_name) == 0: - raise AssertionError( - f"Qualification name '{qual_name}' is not a string with length > 0" - ) + raise AssertionError(f"Qualification name '{qual_name}' is not a string with length > 0") comparator = qual_dict["comparator"] if comparator not in SUPPORTED_COMPARATORS: diff --git a/mephisto/utils/testing.py b/mephisto/utils/testing.py index 0e86ca711..9a0051131 100644 --- a/mephisto/utils/testing.py +++ b/mephisto/utils/testing.py @@ -91,9 +91,7 @@ def get_test_task_run(db: MephistoDB) -> str: task_name, task_id = get_test_task(db) requester_name, requester_id = get_test_requester(db) init_params = OmegaConf.to_yaml(OmegaConf.structured(MOCK_CONFIG)) - return db.new_task_run( - task_id, requester_id, json.dumps(init_params), "mock", "mock" - ) + return db.new_task_run(task_id, requester_id, json.dumps(init_params), "mock", "mock") def get_test_assignment(db: MephistoDB) -> str: diff --git a/scripts/check_npm_package_versions.py b/scripts/check_npm_package_versions.py index 3b5b8a2ba..c9e8570fa 100644 --- a/scripts/check_npm_package_versions.py +++ b/scripts/check_npm_package_versions.py @@ -25,9 +25,7 @@ def run_check(): all_success = True for pkg in CHECK_PACKAGES: package_location = os.path.join(ROOT_DIR, "packages", pkg, "package.json") - assert os.path.exists( - package_location - ), f"Can't find package {pkg} at {package_location}" + assert os.path.exists(package_location), f"Can't find package {pkg} at {package_location}" with open(package_location) as package_json: version = json.load(package_json)["version"] diff --git a/scripts/sync_mephisto_task.py b/scripts/sync_mephisto_task.py index bae26aa76..68454fdf3 100644 --- a/scripts/sync_mephisto_task.py +++ b/scripts/sync_mephisto_task.py @@ -40,9 +40,7 @@ def run_replace(): print(f"Detected mephisto-task version '{version}' at '{MEPHISTO_TASK_PACKAGE}'") if is_check_mode: - print( - f"Checking all dependent files are using mephisto-task version '{version}'...\n" - ) + print(f"Checking all dependent files are using mephisto-task version '{version}'...\n") else: print(f"Syncing all dependent files to mephisto-task version '{version}'...\n") output = f'CURR_MEPHISTO_TASK_VERSION = "{version}"' diff --git a/test/abstractions/architects/test_local_architect.py b/test/abstractions/architects/test_local_architect.py index fcad3dd1b..ac196da33 100644 --- a/test/abstractions/architects/test_local_architect.py +++ b/test/abstractions/architects/test_local_architect.py @@ -62,9 +62,7 @@ def server_is_cleaned(self, build_dir: str) -> bool: def server_is_shutdown(self) -> bool: """Ensure process is no longer running""" assert self.curr_architect is not None, "No architect to check" - assert ( - self.curr_architect.server_process is not None - ), "architect has no server process" + assert self.curr_architect.server_process is not None, "architect has no server process" return self.curr_architect.server_process.returncode is not None # TODO(#102) maybe a test where we need to re-instance an architect? @@ -75,10 +73,7 @@ def tearDown(self) -> None: if self.curr_architect is not None: if self.curr_architect.running_dir is not None: sh.rm(shlex.split("-rf " + self.curr_architect.running_dir)) - if ( - self.curr_architect.server_process is not None - and not self.server_is_shutdown() - ): + if self.curr_architect.server_process is not None and not self.server_is_shutdown(): self.curr_architect.server_process.terminate() self.curr_architect.server_process.wait() diff --git a/test/abstractions/blueprints/test_mixin_core.py b/test/abstractions/blueprints/test_mixin_core.py index 28f830040..4026865f9 100644 --- a/test/abstractions/blueprints/test_mixin_core.py +++ b/test/abstractions/blueprints/test_mixin_core.py @@ -44,9 +44,7 @@ def init_mixin_config( return @classmethod - def assert_mixin_args( - cls, args: "DictConfig", shared_state: "SharedTaskState" - ) -> None: + def assert_mixin_args(cls, args: "DictConfig", shared_state: "SharedTaskState") -> None: return @classmethod @@ -81,9 +79,7 @@ def init_mixin_config( self.mixin_init_calls = 1 @classmethod - def assert_mixin_args( - cls, args: "DictConfig", shared_state: "SharedTaskState" - ) -> None: + def assert_mixin_args(cls, args: "DictConfig", shared_state: "SharedTaskState") -> None: assert args.blueprint.arg1 == 0, "Was not the default value of arg1" @classmethod @@ -118,9 +114,7 @@ def init_mixin_config( self.mixin_init_calls = 1 @classmethod - def assert_mixin_args( - cls, args: "DictConfig", shared_state: "SharedTaskState" - ) -> None: + def assert_mixin_args(cls, args: "DictConfig", shared_state: "SharedTaskState") -> None: assert args.blueprint.arg2 == 0, "Was not the default value of arg2" @classmethod @@ -147,9 +141,7 @@ class ComposedMixin(MockBlueprintMixin1, MockBlueprintMixin2): mixin_init_calls: int @classmethod - def assert_mixin_args( - cls, args: "DictConfig", shared_state: "SharedTaskState" - ) -> None: + def assert_mixin_args(cls, args: "DictConfig", shared_state: "SharedTaskState") -> None: MockBlueprintMixin1.assert_mixin_args(args, shared_state) MockBlueprintMixin2.assert_mixin_args(args, shared_state) @@ -193,9 +185,7 @@ def get_initialization_data(self): shared_state = TestBlueprint.SharedStateClass() cfg = self.get_structured_config(args) - with self.assertRaises( - AttributeError, msg="Undefined mixin classes should fail here" - ): + with self.assertRaises(AttributeError, msg="Undefined mixin classes should fail here"): @BrokenMixin.mixin_args_and_state class TestBlueprint(BrokenMixin, Blueprint): @@ -234,9 +224,7 @@ def get_initialization_data(self): cfg = self.get_structured_config(args) TestBlueprint.assert_task_args(cfg, shared_state) blueprint = TestBlueprint(self.task_run, cfg, shared_state) - self.assertEqual( - blueprint.mixin_init_calls, 1, "More than one mixin init call!" - ) + self.assertEqual(blueprint.mixin_init_calls, 1, "More than one mixin init call!") # Working mixin using the decorator @MockBlueprintMixin1.mixin_args_and_state @@ -249,9 +237,7 @@ def get_initialization_data(self): cfg = self.get_structured_config(args) TestBlueprint.assert_task_args(cfg, shared_state) blueprint = TestBlueprint(self.task_run, cfg, shared_state) - self.assertEqual( - blueprint.mixin_init_calls, 1, "More than one mixin init call!" - ) + self.assertEqual(blueprint.mixin_init_calls, 1, "More than one mixin init call!") def test_mixin_multi_inheritence(self): @MockBlueprintMixin1.mixin_args_and_state @@ -268,20 +254,14 @@ def get_initialization_data(self): self.assertEqual(blueprint.mixin_init_calls, 2, "Should have 2 mixin calls") # Ensure qualifications are correct - required_quals = DoubleMixinBlueprint.get_required_qualifications( - args, shared_state - ) - self.assertEqual( - len(BlueprintMixin.extract_unique_mixins(DoubleMixinBlueprint)), 2 - ) + required_quals = DoubleMixinBlueprint.get_required_qualifications(args, shared_state) + self.assertEqual(len(BlueprintMixin.extract_unique_mixins(DoubleMixinBlueprint)), 2) qual_names = [q["qual_name"] for q in required_quals] self.assertIn(MockBlueprintMixin1.MOCK_QUAL_NAME, qual_names) self.assertIn(MockBlueprintMixin2.MOCK_QUAL_NAME, qual_names) # Check functionality of important helpers - self.assertEqual( - len(BlueprintMixin.extract_unique_mixins(DoubleMixinBlueprint)), 2 - ) + self.assertEqual(len(BlueprintMixin.extract_unique_mixins(DoubleMixinBlueprint)), 2) # Ensure failures work for each of the arg failures shared_state = DoubleMixinBlueprint.SharedStateClass() @@ -309,19 +289,13 @@ def get_initialization_data(self): self.assertEqual(blueprint.mixin_init_calls, 1, "Should have 1 mixin call") # Ensure qualifications are correct - required_quals = ComposedBlueprint.get_required_qualifications( - args, shared_state - ) - self.assertEqual( - len(BlueprintMixin.extract_unique_mixins(ComposedBlueprint)), 1 - ) + required_quals = ComposedBlueprint.get_required_qualifications(args, shared_state) + self.assertEqual(len(BlueprintMixin.extract_unique_mixins(ComposedBlueprint)), 1) qual_names = [q["qual_name"] for q in required_quals] self.assertIn(ComposedBlueprint.MOCK_QUAL_NAME, qual_names) # Check functionality of important helpers - self.assertEqual( - len(BlueprintMixin.extract_unique_mixins(ComposedBlueprint)), 1 - ) + self.assertEqual(len(BlueprintMixin.extract_unique_mixins(ComposedBlueprint)), 1) # Ensure failures work for each of the arg failures shared_state = ComposedBlueprint.SharedStateClass() diff --git a/test/abstractions/blueprints/test_mock_blueprint.py b/test/abstractions/blueprints/test_mock_blueprint.py index 12834b12c..49145c24a 100644 --- a/test/abstractions/blueprints/test_mock_blueprint.py +++ b/test/abstractions/blueprints/test_mock_blueprint.py @@ -97,9 +97,7 @@ def get_test_assignment(self) -> Assignment: Agent = MockAgent.get(self.db, agent_id) return assign - def assignment_is_tracked( - self, task_runner: TaskRunner, assignment: Assignment - ) -> bool: + def assignment_is_tracked(self, task_runner: TaskRunner, assignment: Assignment) -> bool: """ Return whether or not this task is currently being tracked (run) by the given task runner. This should be false unless diff --git a/test/abstractions/providers/mturk_sandbox/test_mturk_provider.py b/test/abstractions/providers/mturk_sandbox/test_mturk_provider.py index 5f1c95ae6..5522e551f 100644 --- a/test/abstractions/providers/mturk_sandbox/test_mturk_provider.py +++ b/test/abstractions/providers/mturk_sandbox/test_mturk_provider.py @@ -78,21 +78,15 @@ def test_grant_and_revoke_qualifications(self) -> None: qualification_name = f"mephisto_test_qualification_{int(time.time())}" extended_qualification_name = f"{qualification_name}_sandbox" - qual_mapping = worker.datastore.get_qualification_mapping( - extended_qualification_name - ) + qual_mapping = worker.datastore.get_qualification_mapping(extended_qualification_name) self.assertIsNone(qual_mapping) mephisto_qual_id = db.make_qualification(qualification_name) - self.assertTrue( - worker.grant_qualification(qualification_name), "Qualification not granted" - ) + self.assertTrue(worker.grant_qualification(qualification_name), "Qualification not granted") # ensure the qualification exists - qual_mapping = worker.datastore.get_qualification_mapping( - extended_qualification_name - ) + qual_mapping = worker.datastore.get_qualification_mapping(extended_qualification_name) self.assertIsNotNone(qual_mapping) assert qual_mapping is not None, "For typing, already asserted this isn't None" @@ -112,15 +106,11 @@ def cleanup_qualification(): worker.revoke_qualification(qualification_name), "Qualification not revoked" ) - owned, found_qual = find_qualification( - client, qual_mapping["mturk_qualification_name"] - ) + owned, found_qual = find_qualification(client, qual_mapping["mturk_qualification_name"]) start_time = time.time() while found_qual is None: time.sleep(1) - owned, found_qual = find_qualification( - client, qual_mapping["mturk_qualification_name"] - ) + owned, found_qual = find_qualification(client, qual_mapping["mturk_qualification_name"]) self.assertFalse( time.time() - start_time > 20, "MTurk did not register qualification creation", @@ -132,9 +122,7 @@ def cleanup_qualification(): # TODO(#97) assert the worker does not have the qualification - self.assertTrue( - worker.grant_qualification(qualification_name), "Qualification not granted" - ) + self.assertTrue(worker.grant_qualification(qualification_name), "Qualification not granted") # TODO(#97) assert that the worker has the qualification @@ -144,21 +132,15 @@ def cleanup_qualification(): # TODO(#97) assert the worker no longer has the qualification again - self.assertFalse( - worker.revoke_qualification(qualification_name), "Can't revoke qual twice" - ) + self.assertFalse(worker.revoke_qualification(qualification_name), "Can't revoke qual twice") db.delete_qualification(qualification_name) - owned, found_qual = find_qualification( - client, qual_mapping["mturk_qualification_name"] - ) + owned, found_qual = find_qualification(client, qual_mapping["mturk_qualification_name"]) start_time = time.time() while found_qual is not None: time.sleep(1) - owned, found_qual = find_qualification( - client, qual_mapping["mturk_qualification_name"] - ) + owned, found_qual = find_qualification(client, qual_mapping["mturk_qualification_name"]) self.assertFalse( time.time() - start_time > 20, "MTurk did not register qualification deletion", diff --git a/test/abstractions/providers/prolific/test_prolific_utils.py b/test/abstractions/providers/prolific/test_prolific_utils.py index 653f53d52..042ddcfcf 100644 --- a/test/abstractions/providers/prolific/test_prolific_utils.py +++ b/test/abstractions/providers/prolific/test_prolific_utils.py @@ -26,25 +26,25 @@ from mephisto.data_model.requester import RequesterArgs from mephisto.data_model.task_run import TaskRunArgs -MOCK_PROLIFIC_CONFIG_DIR = '/tmp/' -MOCK_PROLIFIC_CONFIG_PATH = '/tmp/test_conf_credentials' +MOCK_PROLIFIC_CONFIG_DIR = "/tmp/" +MOCK_PROLIFIC_CONFIG_PATH = "/tmp/test_conf_credentials" @dataclass class MockProlificRequesterArgs(RequesterArgs): name: str = field( - default='prolific', + default="prolific", ) api_key: str = field( - default='prolific', + default="prolific", ) mock_task_run_args = TaskRunArgs( - task_title='title', - task_description='This is a description', + task_title="title", + task_description="This is a description", task_reward=0.3, - task_tags='1,2,3', + task_tags="1,2,3", task_lifetime_in_seconds=1, ) @@ -52,44 +52,45 @@ class MockProlificRequesterArgs(RequesterArgs): @pytest.mark.prolific class TestProlificUtils(unittest.TestCase): """Unit testing for Prolific Utils""" + @staticmethod def remove_credentials_file(): if os.path.exists(MOCK_PROLIFIC_CONFIG_PATH): os.remove(MOCK_PROLIFIC_CONFIG_PATH) - @patch('mephisto.abstractions.providers.prolific.api.users.Users.me') + @patch("mephisto.abstractions.providers.prolific.api.users.Users.me") def test_check_credentials_true(self, mock_prolific_users_me, *args): - mock_prolific_users_me.return_value = User(id='test') + mock_prolific_users_me.return_value = User(id="test") result = check_credentials() self.assertTrue(result) - @patch('mephisto.abstractions.providers.prolific.api.users.Users.me') + @patch("mephisto.abstractions.providers.prolific.api.users.Users.me") def test_check_credentials_false(self, mock_prolific_users_me, *args): mock_prolific_users_me.side_effect = ProlificRequestError() result = check_credentials() self.assertFalse(result) @patch( - 'mephisto.abstractions.providers.prolific.prolific_utils.CREDENTIALS_CONFIG_DIR', + "mephisto.abstractions.providers.prolific.prolific_utils.CREDENTIALS_CONFIG_DIR", MOCK_PROLIFIC_CONFIG_DIR, ) @patch( - 'mephisto.abstractions.providers.prolific.prolific_utils.CREDENTIALS_CONFIG_PATH', + "mephisto.abstractions.providers.prolific.prolific_utils.CREDENTIALS_CONFIG_PATH", MOCK_PROLIFIC_CONFIG_PATH, ) def test_setup_credentials(self, *args): self.remove_credentials_file() self.assertFalse(os.path.exists(MOCK_PROLIFIC_CONFIG_PATH)) cfg = MockProlificRequesterArgs() - setup_credentials('name', cfg) + setup_credentials("name", cfg) self.assertTrue(os.path.exists(MOCK_PROLIFIC_CONFIG_PATH)) self.remove_credentials_file() - @patch('mephisto.abstractions.providers.prolific.api.participant_groups.ParticipantGroups.list') + @patch("mephisto.abstractions.providers.prolific.api.participant_groups.ParticipantGroups.list") def test_find_qualification_success(self, mock_participant_groups_list, *args): prolific_project_id = uuid4().hex[:24] - qualification_name = 'test' - qualification_description = 'test' + qualification_name = "test" + qualification_description = "test" expected_qualification_id = uuid4().hex[:24] mock_participant_groups_list.return_value = [ ParticipantGroup( @@ -102,44 +103,49 @@ def test_find_qualification_success(self, mock_participant_groups_list, *args): _, q = _find_qualification(prolific_api, prolific_project_id, qualification_name) self.assertEqual(q.id, expected_qualification_id) - @patch('mephisto.abstractions.providers.prolific.api.participant_groups.ParticipantGroups.list') + @patch("mephisto.abstractions.providers.prolific.api.participant_groups.ParticipantGroups.list") def test_find_qualification_no_qualification(self, mock_participant_groups_list, *args): prolific_project_id = uuid4().hex[:24] - qualification_name = 'test' + qualification_name = "test" mock_participant_groups_list.return_value = [] result = _find_qualification(prolific_api, prolific_project_id, qualification_name) self.assertEqual(result, (True, None)) - @patch('mephisto.abstractions.providers.prolific.api.participant_groups.ParticipantGroups.list') + @patch("mephisto.abstractions.providers.prolific.api.participant_groups.ParticipantGroups.list") def test_find_qualification_error(self, mock_participant_groups_list, *args): prolific_project_id = uuid4().hex[:24] - qualification_name = 'test' - exception_message = 'Error' + qualification_name = "test" + exception_message = "Error" mock_participant_groups_list.side_effect = ProlificRequestError(exception_message) with self.assertRaises(ProlificRequestError) as cm: _find_qualification(prolific_api, prolific_project_id, qualification_name) self.assertEqual(cm.exception.message, exception_message) - @patch('mephisto.abstractions.providers.prolific.prolific_utils._find_qualification') + @patch("mephisto.abstractions.providers.prolific.prolific_utils._find_qualification") def test_find_or_create_qualification_found_one(self, mock_find_qualification, *args): prolific_project_id = uuid4().hex[:24] - qualification_name = 'test' + qualification_name = "test" expected_qualification_id = uuid4().hex[:24] mock_find_qualification.return_value = (True, expected_qualification_id) result = find_or_create_qualification( - prolific_api, prolific_project_id, qualification_name, + prolific_api, + prolific_project_id, + qualification_name, ) self.assertEqual(result.id, expected_qualification_id) @patch( - 'mephisto.abstractions.providers.prolific.api.participant_groups.ParticipantGroups.create' + "mephisto.abstractions.providers.prolific.api.participant_groups.ParticipantGroups.create" ) - @patch('mephisto.abstractions.providers.prolific.prolific_utils._find_qualification') + @patch("mephisto.abstractions.providers.prolific.prolific_utils._find_qualification") def test_find_or_create_qualification_created_new( - self, mock_find_qualification, mock_participant_groups_create, *args, + self, + mock_find_qualification, + mock_participant_groups_create, + *args, ): - qualification_name = 'test' - qualification_description = 'test' + qualification_name = "test" + qualification_description = "test" expected_qualification_id = uuid4().hex[:24] mock_find_qualification.return_value = (False, None) mock_participant_groups_create.return_value = ParticipantGroup( @@ -148,36 +154,43 @@ def test_find_or_create_qualification_created_new( description=qualification_description, ) result = find_or_create_qualification( - prolific_api, qualification_name, qualification_description, + prolific_api, + qualification_name, + qualification_description, ) self.assertEqual(result.id, expected_qualification_id) @patch( - 'mephisto.abstractions.providers.prolific.api.participant_groups.ParticipantGroups.create' + "mephisto.abstractions.providers.prolific.api.participant_groups.ParticipantGroups.create" ) - @patch('mephisto.abstractions.providers.prolific.prolific_utils._find_qualification') + @patch("mephisto.abstractions.providers.prolific.prolific_utils._find_qualification") def test_find_or_create_qualification_error( - self, mock_find_qualification, mock_participant_groups_create, *args, + self, + mock_find_qualification, + mock_participant_groups_create, + *args, ): - qualification_name = 'test' - qualification_description = 'test' + qualification_name = "test" + qualification_description = "test" mock_find_qualification.return_value = (False, None) - exception_message = 'Error' + exception_message = "Error" mock_participant_groups_create.side_effect = ProlificRequestError(exception_message) with self.assertRaises(ProlificRequestError) as cm: find_or_create_qualification( - prolific_api, qualification_name, qualification_description, + prolific_api, + qualification_name, + qualification_description, ) self.assertEqual(cm.exception.message, exception_message) - @patch('mephisto.abstractions.providers.prolific.api.studies.Studies.create') + @patch("mephisto.abstractions.providers.prolific.api.studies.Studies.create") def test_create_study_success(self, mock_study_create, *args): project_id = uuid4().hex[:24] expected_task_id = uuid4().hex[:24] mock_study_create.return_value = Study( project=project_id, id=expected_task_id, - name='test', + name="test", ) study = create_study( client=prolific_api, @@ -186,10 +199,10 @@ def test_create_study_success(self, mock_study_create, *args): ) self.assertEqual(study.id, expected_task_id) - @patch('mephisto.abstractions.providers.prolific.api.studies.Studies.create') + @patch("mephisto.abstractions.providers.prolific.api.studies.Studies.create") def test_create_study_error(self, mock_study_create, *args): project_id = uuid4().hex[:24] - exception_message = 'Error' + exception_message = "Error" mock_study_create.side_effect = ProlificRequestError(exception_message) with self.assertRaises(ProlificRequestError) as cm: create_study( diff --git a/test/core/test_live_runs.py b/test/core/test_live_runs.py index 0169370cf..97453498e 100644 --- a/test/core/test_live_runs.py +++ b/test/core/test_live_runs.py @@ -87,9 +87,7 @@ def setUp(self): self.provider.setup_resources_for_task_run( self.task_run, self.task_run.args, EMPTY_STATE, self.url ) - self.launcher = TaskLauncher( - self.db, self.task_run, self.get_mock_assignment_data_array() - ) + self.launcher = TaskLauncher(self.db, self.task_run, self.get_mock_assignment_data_array()) self.launcher.create_assignments() self.launcher.launch_units(self.url) self.client_io = ClientIOHandler(self.db) @@ -173,8 +171,7 @@ def assert_sandbox_worker_created(self, live_run, worker_name, timeout=2) -> Non self.assertTrue( # type: ignore self._run_loop_until( live_run, - lambda: len(self.db.find_workers(worker_name=worker_name + "_sandbox")) - > 0, + lambda: len(self.db.find_workers(worker_name=worker_name + "_sandbox")) > 0, timeout, ), f"Worker {worker_name} not created in time!", @@ -280,9 +277,7 @@ def test_register_concurrent_run(self): len(live_run.worker_pool.agents), 1, "Agent not registered with worker pool" ) - self.assertEqual( - len(task_runner.running_units), 1, "Ready task was not launched" - ) + self.assertEqual(len(task_runner.running_units), 1, "Ready task was not launched") # Register another worker mock_worker_name = "MOCK_WORKER_2" @@ -380,13 +375,9 @@ def test_register_run(self): self.assertEqual(len(agents), 1, "Agent may have been duplicated") agent = agents[0] self.assertIsNotNone(agent) - self.assertEqual( - len(self.worker_pool.agents), 1, "Agent not registered with worker pool" - ) + self.assertEqual(len(self.worker_pool.agents), 1, "Agent not registered with worker pool") - self.assertEqual( - len(task_runner.running_assignments), 0, "Task was not yet ready" - ) + self.assertEqual(len(task_runner.running_assignments), 0, "Task was not yet ready") # Register another worker mock_worker_name = "MOCK_WORKER_2" @@ -396,9 +387,7 @@ def test_register_run(self): self.architect.server.register_mock_agent(mock_worker_name, mock_agent_details) self.await_channel_requests(live_run) - self.assertEqual( - len(task_runner.running_assignments), 1, "Task was not launched" - ) + self.assertEqual(len(task_runner.running_assignments), 1, "Task was not launched") agents = [a for a in self.worker_pool.agents.values()] # Make both agents act @@ -488,20 +477,14 @@ def test_register_concurrent_run_with_onboarding(self): self.assertEqual(len(workers), 1, "Worker not successfully registered") worker_0 = workers[0] agents = self.db.find_agents() - self.assertEqual( - len(agents), 0, "Agent should not be created yet - need onboarding" - ) + self.assertEqual(len(agents), 0, "Agent should not be created yet - need onboarding") onboard_agents = self.db.find_onboarding_agents() - self.assertEqual( - len(onboard_agents), 1, "Onboarding agent should have been created" - ) + self.assertEqual(len(onboard_agents), 1, "Onboarding agent should have been created") last_packet = self.architect.server.last_packet self.assertIsNotNone(last_packet) if not last_packet["data"].get("status") == "onboarding": - self.assertIn( - "onboard_data", last_packet["data"], "Onboarding not triggered" - ) + self.assertIn("onboard_data", last_packet["data"], "Onboarding not triggered") self.architect.server.last_packet = None # Submit onboarding from the agent @@ -539,9 +522,7 @@ def test_register_concurrent_run_with_onboarding(self): self.architect.server.register_mock_agent(mock_worker_name, mock_agent_details) self.await_channel_requests(live_run) agents = self.db.find_agents() - self.assertEqual( - len(agents), 0, "Agent should not be created yet, failed onboarding" - ) + self.assertEqual(len(agents), 0, "Agent should not be created yet, failed onboarding") last_packet = self.architect.server.last_packet self.assertIsNotNone(last_packet) @@ -550,9 +531,7 @@ def test_register_concurrent_run_with_onboarding(self): last_packet["data"], "Onboarding triggered for disqualified worker", ) - self.assertIsNone( - last_packet["data"]["agent_id"], "worker assigned real agent id" - ) + self.assertIsNone(last_packet["data"]["agent_id"], "worker assigned real agent id") self.architect.server.last_packet = None self.db.revoke_qualification(qualification_id, worker_1.db_id) @@ -561,20 +540,14 @@ def test_register_concurrent_run_with_onboarding(self): self.architect.server.register_mock_agent(mock_worker_name, mock_agent_details) self.await_channel_requests(live_run) agents = self.db.find_agents() - self.assertEqual( - len(agents), 0, "Agent should not be created yet - need onboarding" - ) + self.assertEqual(len(agents), 0, "Agent should not be created yet - need onboarding") onboard_agents = self.db.find_onboarding_agents() - self.assertEqual( - len(onboard_agents), 2, "Onboarding agent should have been created" - ) + self.assertEqual(len(onboard_agents), 2, "Onboarding agent should have been created") last_packet = self.architect.server.last_packet self.assertIsNotNone(last_packet) if not last_packet["data"].get("status") == "onboarding": - self.assertIn( - "onboard_data", last_packet["data"], "Onboarding not triggered" - ) + self.assertIn("onboard_data", last_packet["data"], "Onboarding not triggered") self.architect.server.last_packet = None # Submit onboarding from the agent @@ -593,9 +566,7 @@ def test_register_concurrent_run_with_onboarding(self): self.assertEqual(len(agents), 1, "Agent may have been duplicated") agent = agents[0] self.assertIsNotNone(agent) - self.assertEqual( - len(self.worker_pool.agents), 1, "Agent not registered with worker pool" - ) + self.assertEqual(len(self.worker_pool.agents), 1, "Agent not registered with worker pool") self.assertEqual( len(task_runner.running_assignments), @@ -623,9 +594,7 @@ def test_register_concurrent_run_with_onboarding(self): agents = self.db.find_agents() self.assertEqual(len(agents), 2, "Second agent not created without onboarding") - self.assertEqual( - len(task_runner.running_assignments), 1, "Task was not launched" - ) + self.assertEqual(len(task_runner.running_assignments), 1, "Task was not launched") self.assertFalse(worker_0.is_qualified(TEST_QUALIFICATION_NAME)) self.assertTrue(worker_0.is_disqualified(TEST_QUALIFICATION_NAME)) @@ -723,9 +692,7 @@ def test_register_run_with_onboarding(self): self.architect.server.register_mock_agent(mock_worker_name, mock_agent_details) self.await_channel_requests(live_run) agents = self.db.find_agents() - self.assertEqual( - len(agents), 0, "Agent should not be created yet, failed onboarding" - ) + self.assertEqual(len(agents), 0, "Agent should not be created yet, failed onboarding") last_packet = self.architect.server.last_packet self.assertIsNotNone(last_packet) @@ -734,9 +701,7 @@ def test_register_run_with_onboarding(self): last_packet["data"], "Onboarding triggered for disqualified worker", ) - self.assertIsNone( - last_packet["data"]["agent_id"], "worker assigned real agent id" - ) + self.assertIsNone(last_packet["data"]["agent_id"], "worker assigned real agent id") self.architect.server.last_packet = None self.db.revoke_qualification(qualification_id, worker_1.db_id) @@ -745,20 +710,14 @@ def test_register_run_with_onboarding(self): self.architect.server.register_mock_agent(mock_worker_name, mock_agent_details) self.await_channel_requests(live_run) agents = self.db.find_agents() - self.assertEqual( - len(agents), 0, "Agent should not be created yet - need onboarding" - ) + self.assertEqual(len(agents), 0, "Agent should not be created yet - need onboarding") onboard_agents = self.db.find_onboarding_agents() - self.assertEqual( - len(onboard_agents), 1, "Onboarding agent should have been created" - ) + self.assertEqual(len(onboard_agents), 1, "Onboarding agent should have been created") last_packet = self.architect.server.last_packet self.assertIsNotNone(last_packet) if not last_packet["data"].get("status") == "onboarding": - self.assertIn( - "onboard_data", last_packet["data"], "Onboarding not triggered" - ) + self.assertIn("onboard_data", last_packet["data"], "Onboarding not triggered") self.architect.server.last_packet = None # Submit onboarding from the agent @@ -820,20 +779,14 @@ def test_register_run_with_onboarding(self): workers = self.db.find_workers(worker_name=mock_worker_name + "_sandbox") worker_3 = workers[0] agents = self.db.find_agents() - self.assertEqual( - len(agents), 1, "Agent should not be created yet - need onboarding" - ) + self.assertEqual(len(agents), 1, "Agent should not be created yet - need onboarding") onboard_agents = self.db.find_onboarding_agents() - self.assertEqual( - len(onboard_agents), 2, "Onboarding agent should have been created" - ) + self.assertEqual(len(onboard_agents), 2, "Onboarding agent should have been created") self._await_current_tasks(live_run, 2) last_packet = self.architect.server.last_packet self.assertIsNotNone(last_packet) if not last_packet["data"].get("status") == "onboarding": - self.assertIn( - "onboard_data", last_packet["data"], "Onboarding not triggered" - ) + self.assertIn("onboard_data", last_packet["data"], "Onboarding not triggered") self.architect.server.last_packet = None # Submit onboarding from the agent @@ -858,9 +811,7 @@ def test_register_run_with_onboarding(self): "Agent not registered to worker pool after onboarding", ) - self.assertEqual( - len(task_runner.running_units), 2, "Task not launched after onboarding" - ) + self.assertEqual(len(task_runner.running_units), 2, "Task not launched after onboarding") agents = [a for a in self.worker_pool.agents.values()] @@ -970,9 +921,7 @@ def screen_unit(unit): # Register a screening agent successfully mock_agent_details = "FAKE_ASSIGNMENT" - self.architect.server.register_mock_agent( - mock_worker_name_1, mock_agent_details - ) + self.architect.server.register_mock_agent(mock_worker_name_1, mock_agent_details) self.await_channel_requests(live_run) workers = self.db.find_workers(worker_name=mock_worker_name_1 + "_sandbox") self.assertEqual(len(workers), 1, "Worker not successfully registered") @@ -988,9 +937,7 @@ def screen_unit(unit): # Register a second screening agent successfully mock_agent_details = "FAKE_ASSIGNMENT2" - self.architect.server.register_mock_agent( - mock_worker_name_2, mock_agent_details - ) + self.architect.server.register_mock_agent(mock_worker_name_2, mock_agent_details) self.await_channel_requests(live_run) workers = self.db.find_workers(worker_name=mock_worker_name_2 + "_sandbox") worker_2 = workers[0] @@ -1006,9 +953,7 @@ def screen_unit(unit): # Fail to register a third screening agent mock_agent_details = "FAKE_ASSIGNMENT3" - self.architect.server.register_mock_agent( - mock_worker_name_3, mock_agent_details - ) + self.architect.server.register_mock_agent(mock_worker_name_3, mock_agent_details) self.await_channel_requests(live_run) workers = self.db.find_workers(worker_name=mock_worker_name_3 + "_sandbox") worker_3 = workers[0] @@ -1020,9 +965,7 @@ def screen_unit(unit): # Register third screening agent mock_agent_details = "FAKE_ASSIGNMENT3" - self.architect.server.register_mock_agent( - mock_worker_name_3, mock_agent_details - ) + self.architect.server.register_mock_agent(mock_worker_name_3, mock_agent_details) self.await_channel_requests(live_run) agents = self.db.find_agents() self.assertEqual(len(agents), 3, "Third agent not created") @@ -1066,9 +1009,7 @@ def screen_unit(unit): # Accept a real task, and complete it, from worker 3 # Register a task agent successfully mock_agent_details = "FAKE_ASSIGNMENT4" - self.architect.server.register_mock_agent( - mock_worker_name_3, mock_agent_details - ) + self.architect.server.register_mock_agent(mock_worker_name_3, mock_agent_details) self.await_channel_requests(live_run) agents = self.db.find_agents() self.assertEqual(len(agents), 4, "No agent created for task") diff --git a/test/core/test_operator.py b/test/core/test_operator.py index 7686b9b55..8478a6df8 100644 --- a/test/core/test_operator.py +++ b/test/core/test_operator.py @@ -69,15 +69,11 @@ def tearDown(self): shutil.rmtree(self.data_dir, ignore_errors=True) SHUTDOWN_TIMEOUT = 10 threads = threading.enumerate() - target_threads = [ - t for t in threads if not isinstance(t, TMonitor) and not t.daemon - ] + target_threads = [t for t in threads if not isinstance(t, TMonitor) and not t.daemon] start_time = time.time() while len(target_threads) > 1 and time.time() - start_time < SHUTDOWN_TIMEOUT: threads = threading.enumerate() - target_threads = [ - t for t in threads if not isinstance(t, TMonitor) and not t.daemon - ] + target_threads = [t for t in threads if not isinstance(t, TMonitor) and not t.daemon] time.sleep(0.3) self.assertTrue( time.time() - start_time < SHUTDOWN_TIMEOUT, @@ -110,8 +106,7 @@ def test_initialize_operator(self): def assert_sandbox_worker_created(self, worker_name, timeout=2) -> None: self.assertTrue( # type: ignore self.operator._run_loop_until( - lambda: len(self.db.find_workers(worker_name=worker_name + "_sandbox")) - > 0, + lambda: len(self.db.find_workers(worker_name=worker_name + "_sandbox")) > 0, timeout, ), f"Worker {worker_name} not created in time!", @@ -190,9 +185,7 @@ def test_run_job_concurrent(self): # Give up to 5 seconds for whole mock task to complete start_time = time.time() self.operator._wait_for_runs_in_testing(TIMEOUT_TIME) - self.assertLess( - time.time() - start_time, TIMEOUT_TIME, "Task not completed in time" - ) + self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Task not completed in time") # Ensure the assignment is completed task_run = tracked_run.task_run @@ -252,9 +245,7 @@ def test_run_job_not_concurrent(self): # Give up to 5 seconds for both tasks to complete start_time = time.time() self.operator._wait_for_runs_in_testing(TIMEOUT_TIME) - self.assertLess( - time.time() - start_time, TIMEOUT_TIME, "Task not completed in time" - ) + self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Task not completed in time") # Ensure the assignment is completed task_run = tracked_run.task_run @@ -294,9 +285,7 @@ def test_patience_shutdown(self): # Give a few seconds for the operator to shutdown start_time = time.time() self.operator._wait_for_runs_in_testing(TIMEOUT_TIME) - self.assertLess( - time.time() - start_time, TIMEOUT_TIME, "Task shutdown not enacted in time" - ) + self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Task shutdown not enacted in time") # Ensure the task run was forced to shut down task_run = tracked_run.task_run @@ -439,9 +428,7 @@ def test_run_jobs_with_restrictions(self): # Give up to 5 seconds for whole mock task to complete start_time = time.time() self.operator._wait_for_runs_in_testing(TIMEOUT_TIME) - self.assertLess( - time.time() - start_time, TIMEOUT_TIME, "Task not completed in time" - ) + self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Task not completed in time") self.operator.shutdown() # Create a new operator, shutdown is a one-time thing @@ -512,9 +499,7 @@ def test_run_jobs_with_restrictions(self): # Ensure the task run completed and that all assignments are done start_time = time.time() self.operator._wait_for_runs_in_testing(TIMEOUT_TIME) - self.assertLess( - time.time() - start_time, TIMEOUT_TIME, "Task not completed in time" - ) + self.assertLess(time.time() - start_time, TIMEOUT_TIME, "Task not completed in time") task_run = tracked_run.task_run assignments = task_run.get_assignments() for assignment in assignments: diff --git a/test/core/test_task_launcher.py b/test/core/test_task_launcher.py index 367d58536..eff95f528 100644 --- a/test/core/test_task_launcher.py +++ b/test/core/test_task_launcher.py @@ -77,9 +77,7 @@ def get_mock_assignment_data_generator() -> Iterable[InitializationData]: def test_init_on_task_run(self): """Initialize a launcher on a task_run""" - launcher = TaskLauncher( - self.db, self.task_run, self.get_mock_assignment_data_array() - ) + launcher = TaskLauncher(self.db, self.task_run, self.get_mock_assignment_data_array()) self.assertEqual(self.db, launcher.db) self.assertEqual(self.task_run, launcher.task_run) self.assertEqual(len(launcher.assignments), 0) @@ -139,9 +137,7 @@ def test_launch_assignments_with_concurrent_unit_cap(self): launcher.launch_units("dummy-url:3000") start_time = time.time() - while set([u.get_status() for u in launcher.units]) != { - AssignmentState.COMPLETED - }: + while set([u.get_status() for u in launcher.units]) != {AssignmentState.COMPLETED}: for unit in launcher.units: if unit.get_status() == AssignmentState.LAUNCHED: unit.set_db_status(AssignmentState.COMPLETED) diff --git a/test/test_data_model.py b/test/test_data_model.py index 5d86541cf..7363eab67 100644 --- a/test/test_data_model.py +++ b/test/test_data_model.py @@ -29,9 +29,7 @@ def test_ensure_valid_statuses(self): found_keys = [k for k in dir(a_state) if k.upper() == k] found_vals = [getattr(a_state, k) for k in found_keys] for v in found_vals: - self.assertIn( - v, found_valid, f"Expected to find {v} in valid list {found_valid}" - ) + self.assertIn(v, found_valid, f"Expected to find {v} in valid list {found_valid}") for sublist, found_array in SUBARRAYS.items(): for v in found_array: self.assertIn( diff --git a/test/tools/test_data_brower.py b/test/tools/test_data_brower.py index 081489b29..2e14ca5c1 100644 --- a/test/tools/test_data_brower.py +++ b/test/tools/test_data_brower.py @@ -79,9 +79,7 @@ def test_find_workers_by_quals(self) -> None: qualified_ids, f"Worker 3 not in qualified list, found {qualified_ids}", ) - self.assertNotIn( - worker_2.db_id, qualified_ids, "Worker 2 should not be in qualified list" - ) + self.assertNotIn(worker_2.db_id, qualified_ids, "Worker 2 should not be in qualified list") if __name__ == "__main__": diff --git a/test/utils/prolific_api/test_data_models.py b/test/utils/prolific_api/test_data_models.py index 3278921dc..394325def 100644 --- a/test/utils/prolific_api/test_data_models.py +++ b/test/utils/prolific_api/test_data_models.py @@ -16,31 +16,31 @@ class TestDataModelsUtils(unittest.TestCase): def test_study_validation_passed(self, *args): data = { - 'name': 'Name', - 'internal_name': 'internal_name', - 'description': 'Description', - 'external_study_url': 'https://url', - 'prolific_id_option': 'url_parameters', - 'completion_option': 'url', - 'completion_codes': [], - 'total_available_places': 100, - 'estimated_completion_time': 5, - 'reward': 999, - 'device_compatibility': ['desktop'], - 'peripheral_requirements': [], - 'eligibility_requirements': [], + "name": "Name", + "internal_name": "internal_name", + "description": "Description", + "external_study_url": "https://url", + "prolific_id_option": "url_parameters", + "completion_option": "url", + "completion_codes": [], + "total_available_places": 100, + "estimated_completion_time": 5, + "reward": 999, + "device_compatibility": ["desktop"], + "peripheral_requirements": [], + "eligibility_requirements": [], } study = Study(**data) - self.assertEqual(study.name, data['name']) + self.assertEqual(study.name, data["name"]) def test_study_validation_error(self, *args): data = { - 'name': 'Name', + "name": "Name", } with self.assertRaises(ValidationError) as cm: Study(**data).validate() - self.assertEqual(cm.exception.validator, 'required') + self.assertEqual(cm.exception.validator, "required") self.assertEqual(cm.exception.message, "'description' is a required property") From 4d2da488c7cd434563275652cc24f2ca7d43d782 Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Fri, 4 Aug 2023 16:39:38 -0400 Subject: [PATCH 2/3] Running prettier --- docker/docker-compose.dev.vscode.yml | 30 +++++++++---------- docker/docker-compose.dev.yml | 7 ++--- .../hydra_configs/conf/prolific_example.yaml | 28 ++++++++--------- 3 files changed, 32 insertions(+), 33 deletions(-) diff --git a/docker/docker-compose.dev.vscode.yml b/docker/docker-compose.dev.vscode.yml index 626cd4000..f228332f7 100644 --- a/docker/docker-compose.dev.vscode.yml +++ b/docker/docker-compose.dev.vscode.yml @@ -1,4 +1,4 @@ -version: '3' +version: "3" services: fb_mephisto_vscode: @@ -7,9 +7,9 @@ services: context: .. dockerfile: Dockerfile ports: - - '8081:8000' - - '3001:3000' - - '5678:5678' + - "8081:8000" + - "3001:3000" + - "5678:5678" volumes: - ..:/mephisto - ./entrypoints/server.sh:/entrypoint.sh @@ -18,14 +18,14 @@ services: entrypoint: /entrypoint.sh env_file: envs/env.local command: [ - "sh", - "-c", - "pip install debugpy -t /tmp - && - python - /tmp/debugpy - --wait-for-client - --listen 0.0.0.0:5678 - /mephisto/examples/simple_static_task/static_test_script.py - " - ] + "sh", + "-c", + "pip install debugpy -t /tmp + && + python + /tmp/debugpy + --wait-for-client + --listen 0.0.0.0:5678 + /mephisto/examples/simple_static_task/static_test_script.py + ", + ] diff --git a/docker/docker-compose.dev.yml b/docker/docker-compose.dev.yml index fdf521006..efee4dc4c 100644 --- a/docker/docker-compose.dev.yml +++ b/docker/docker-compose.dev.yml @@ -1,4 +1,4 @@ -version: '3' +version: "3" services: fb_mephisto: @@ -7,8 +7,8 @@ services: context: .. dockerfile: Dockerfile ports: - - '8081:8000' - - '3001:3000' + - "8081:8000" + - "3001:3000" volumes: - ..:/mephisto - ./entrypoints/server.prolific.sh:/entrypoint.sh @@ -17,4 +17,3 @@ services: entrypoint: /entrypoint.sh env_file: envs/env.dev command: tail -f /dev/null - diff --git a/examples/simple_static_task/hydra_configs/conf/prolific_example.yaml b/examples/simple_static_task/hydra_configs/conf/prolific_example.yaml index edfa8489a..4bdd2a265 100644 --- a/examples/simple_static_task/hydra_configs/conf/prolific_example.yaml +++ b/examples/simple_static_task/hydra_configs/conf/prolific_example.yaml @@ -7,36 +7,36 @@ mephisto: architect: _architect_type: ec2 profile_name: mephisto-router-iam - subdomain: '0802.1' + subdomain: "0802.1" blueprint: data_csv: ${task_dir}/data_prolific.csv task_source: ${task_dir}/server_files/demo_task.html preview_source: ${task_dir}/server_files/demo_preview.html extra_source_dir: ${task_dir}/server_files/extra_refs units_per_assignment: 2 - log_level: 'debug' + log_level: "debug" task: - task_name: '0802' - task_title: '0802 Task' - task_description: 'This is a simple test of static Prolific tasks.' + task_name: "0802" + task_title: "0802 Task" + task_description: "This is a simple test of static Prolific tasks." task_reward: 70 - task_tags: 'static,task,testing' + task_tags: "static,task,testing" max_num_concurrent_units: 1 provider: - prolific_external_study_url: 'https://example.com?participant_id={{%PROLIFIC_PID%}}&study_id={{%STUDY_ID%}}&submission_id={{%SESSION_ID%}}' - prolific_id_option: 'url_parameters' - prolific_workspace_name: 'My Workspace' - prolific_project_name: 'Project' - prolific_allow_list_group_name: 'Allow list' - prolific_block_list_group_name: 'Block list' + prolific_external_study_url: "https://example.com?participant_id={{%PROLIFIC_PID%}}&study_id={{%STUDY_ID%}}&submission_id={{%SESSION_ID%}}" + prolific_id_option: "url_parameters" + prolific_workspace_name: "My Workspace" + prolific_project_name: "Project" + prolific_allow_list_group_name: "Allow list" + prolific_block_list_group_name: "Block list" prolific_eligibility_requirements: - - name: 'CustomWhitelistEligibilityRequirement' + - name: "CustomWhitelistEligibilityRequirement" white_list: - 6463d32f50a18041930b71be - 6463d3922d7d99360896228f - 6463d40e8d5d2f0cce2b3b23 - 6463d44ed1b61a8fb4e0765a - 6463d488c2f2821eaa2fa13f - - name: 'ApprovalRateEligibilityRequirement' + - name: "ApprovalRateEligibilityRequirement" minimum_approval_rate: 0 maximum_approval_rate: 100 From a925e3f6c3af8af586af868f8de51956b2865134 Mon Sep 17 00:00:00 2001 From: Jack Urbanek Date: Tue, 15 Aug 2023 10:56:04 -0400 Subject: [PATCH 3/3] missing file --- mephisto/configs/logging.py | 39 +++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/mephisto/configs/logging.py b/mephisto/configs/logging.py index 30e5da4c1..b7cfb467f 100644 --- a/mephisto/configs/logging.py +++ b/mephisto/configs/logging.py @@ -10,7 +10,7 @@ BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO") -WRITE_LOG_TO_FILE = os.environ.get("WRITE_LOG_TO_FILE", '0') +WRITE_LOG_TO_FILE = os.environ.get("WRITE_LOG_TO_FILE", "0") _now = datetime.now() date_string = _now.strftime("%Y-%m-%d") @@ -18,25 +18,26 @@ def get_log_handlers(): - """ We enable module-level loggers via env variable (that we can set in the console), - so that hydra doesn't create an empty file for every module-level logger + """We enable module-level loggers via env variable (that we can set in the console), + so that hydra doesn't create an empty file for every module-level logger """ handlers = ["console"] - if WRITE_LOG_TO_FILE == '1': + if WRITE_LOG_TO_FILE == "1": handlers.append("file") # Create dirs recursivelly if they do not exist - os.makedirs( - os.path.join(BASE_DIR, "outputs", date_string, time_string), - exist_ok=True - ) + os.makedirs(os.path.join(BASE_DIR, "outputs", date_string, time_string), exist_ok=True) return handlers def get_log_filename(): - """ Compose logfile path formatted same way as hydra """ + """Compose logfile path formatted same way as hydra""" executed_filename = os.path.splitext(os.path.basename(sys.argv[0]))[0] return os.path.join( - BASE_DIR, "outputs", date_string, time_string, f'{executed_filename}.log', + BASE_DIR, + "outputs", + date_string, + time_string, + f"{executed_filename}.log", ) @@ -58,14 +59,18 @@ def get_log_filename(): "class": "logging.StreamHandler", "formatter": "default", }, - **({ - "file": { - "level": LOG_LEVEL, - "class": "logging.FileHandler", - "filename": log_filename, - "formatter": "default", + **( + { + "file": { + "level": LOG_LEVEL, + "class": "logging.FileHandler", + "filename": log_filename, + "formatter": "default", + } } - } if "file" in log_handlers else {}), + if "file" in log_handlers + else {} + ), }, "loggers": { "": {