diff --git a/b2/console_tool.py b/b2/console_tool.py
index 7774ad053..4568c54ee 100644
--- a/b2/console_tool.py
+++ b/b2/console_tool.py
@@ -9,6 +9,7 @@
 # License https://www.backblaze.com/using_b2_code.html
 #
 ######################################################################
+from __future__ import annotations
 
 import argparse
 import base64
@@ -37,7 +38,7 @@
 from concurrent.futures import Executor, Future, ThreadPoolExecutor
 from contextlib import suppress
 from enum import Enum
-from typing import Any, BinaryIO, Dict, List, Optional, Tuple
+from typing import Any, BinaryIO, List
 
 import argcomplete
 import b2sdk
@@ -393,6 +394,69 @@ def _get_file_retention_setting(cls, args):
         return FileRetentionSetting(file_retention_mode, args.retainUntil)
 
 
+class HeaderFlagsMixin(Described):
+    @classmethod
+    def _setup_parser(cls, parser: argparse.ArgumentParser) -> None:
+        parser.add_argument(
+            '--cache-control',
+            help=
+            "optional Cache-Control header, value based on RFC 2616 section 14.9, example: 'public, max-age=86400'"
+        )
+        parser.add_argument(
+            '--content-disposition',
+            help=
+            "optional Content-Disposition header, value based on RFC 2616 section 19.5.1, example: 'attachment; filename=\"fname.ext\"'"
+        )
+        parser.add_argument(
+            '--content-encoding',
+            help=
+            "optional Content-Encoding header, value based on RFC 2616 section 14.11, example: 'gzip'"
+        )
+        parser.add_argument(
+            '--content-language',
+            help=
+            "optional Content-Language header, value based on RFC 2616 section 14.12, example: 'mi, en'"
+        )
+        parser.add_argument(
+            '--expires',
+            help=
+            "optional Expires header, value based on RFC 2616 section 14.21, example: 'Thu, 01 Dec 2050 16:00:00 GMT'"
+        )
+        super()._setup_parser(parser)
+
+    def _file_info_with_header_args(self, args,
+                                    file_info: dict[str, str] | None) -> dict[str, str] | None:
+        """Construct an updated file_info dictionary.
+        Print a warning if any of the file_info items will be overwritten by explicit header arguments.
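+
+        Example (illustrative, assuming ``--cache-control max-age=3600`` is the only header
+        argument given): ``_file_info_with_header_args(args, {'color': 'blue'})`` returns
+        ``{'color': 'blue', 'b2-cache-control': 'max-age=3600'}``; explicit header arguments
+        take precedence over values already present in ``file_info``.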
+ """ + add_file_info = {} + overwritten = [] + if args.cache_control is not None: + add_file_info['b2-cache-control'] = args.cache_control + if args.content_disposition is not None: + add_file_info['b2-content-disposition'] = args.content_disposition + if args.content_encoding is not None: + add_file_info['b2-content-encoding'] = args.content_encoding + if args.content_language is not None: + add_file_info['b2-content-language'] = args.content_language + if args.expires is not None: + add_file_info['b2-expires'] = args.expires + + for key, value in add_file_info.items(): + if file_info is not None and key in file_info and file_info[key] != value: + overwritten.append(key) + + if overwritten: + self._print_stderr( + 'The following file info items will be overwritten by explicit arguments:\n ' + + '\n '.join(f'{key} = {add_file_info[key]}' for key in overwritten) + ) + + if add_file_info: + return {**(file_info or {}), **add_file_info} + return file_info + + class LegalHoldMixin(Described): """ Setting legal holds requires the **writeFileLegalHolds** capability, and only works in bucket @@ -696,12 +760,12 @@ def _print_json(self, data) -> None: json.dumps(data, indent=4, sort_keys=True, cls=B2CliJsonEncoder), enforce_output=True ) - def _print(self, *args, enforce_output: bool = False, end: Optional[str] = None) -> None: + def _print(self, *args, enforce_output: bool = False, end: str | None = None) -> None: return self._print_standard_descriptor( self.stdout, "stdout", *args, enforce_output=enforce_output, end=end ) - def _print_stderr(self, *args, end: Optional[str] = None) -> None: + def _print_stderr(self, *args, end: str | None = None) -> None: return self._print_standard_descriptor( self.stderr, "stderr", *args, enforce_output=True, end=end ) @@ -712,7 +776,7 @@ def _print_standard_descriptor( descriptor_name: str, *args, enforce_output: bool = False, - end: Optional[str] = None, + end: str | None = None, ) -> None: """ Prints to fd, unless quiet is set. @@ -733,7 +797,7 @@ def _print_helper( descriptor_encoding: str, descriptor_name: str, *args, - end: Optional[str] = None + end: str | None = None ): try: descriptor.write(' '.join(args)) @@ -1010,7 +1074,8 @@ def run(self, args): @B2.register_subcommand class CopyFileById( - DestinationSseMixin, SourceSseMixin, FileRetentionSettingMixin, LegalHoldMixin, Command + HeaderFlagsMixin, DestinationSseMixin, SourceSseMixin, FileRetentionSettingMixin, + LegalHoldMixin, Command ): """ Copy a file version to the given bucket (server-side, **not** via download+upload). 
@@ -1072,6 +1137,7 @@ def run(self, args):
             file_infos = self._parse_file_infos(args.info)
         elif args.noInfo:
             file_infos = {}
+        file_infos = self._file_info_with_header_args(args, file_infos)
 
         if args.metadataDirective is not None:
             self._print_stderr(
@@ -1115,7 +1181,7 @@ def run(self, args):
         self._print_json(file_version)
         return 0
 
-    def _is_ssec(self, encryption: Optional[EncryptionSetting]):
+    def _is_ssec(self, encryption: EncryptionSetting | None):
         if encryption is not None and encryption.mode == EncryptionMode.SSE_C:
             return True
         return False
@@ -1123,12 +1189,12 @@ def _is_ssec(self, encryption: Optional[EncryptionSetting]):
     def _determine_source_metadata(
         self,
         source_file_id: str,
-        destination_encryption: Optional[EncryptionSetting],
-        source_encryption: Optional[EncryptionSetting],
-        target_file_info: Optional[dict],
-        target_content_type: Optional[str],
+        destination_encryption: EncryptionSetting | None,
+        source_encryption: EncryptionSetting | None,
+        target_file_info: dict | None,
+        target_content_type: str | None,
         fetch_if_necessary: bool,
-    ) -> Tuple[Optional[dict], Optional[str]]:
+    ) -> tuple[dict | None, str | None]:
         """Determine if source file metadata is necessary to perform the copy - due to sse_c_key_id"""
         if not self._is_ssec(source_encryption) and not self._is_ssec(
             destination_encryption
@@ -1361,6 +1427,8 @@ def _print_download_info(self, downloaded_file: DownloadedFile):
             'Legal hold', self._represent_legal_hold(download_version.legal_hold)
         )
         for label, attr_name in [
+            ('CacheControl', 'cache_control'),
+            ('Expires', 'expires'),
             ('ContentDisposition', 'content_disposition'),
             ('ContentLanguage', 'content_language'),
             ('ContentEncoding', 'content_encoding'),
@@ -1997,7 +2065,7 @@ def _print_file_version(
         self,
         args,
         file_version: FileVersion,
-        folder_name: Optional[str],
+        folder_name: str | None,
     ) -> None:
         self._print(folder_name or file_version.file_name)
 
@@ -2104,7 +2172,7 @@ def _print_file_version(
         self,
         args,
         file_version: FileVersion,
-        folder_name: Optional[str],
+        folder_name: str | None,
     ) -> None:
         if not args.long:
             super()._print_file_version(args, file_version, folder_name)
@@ -2226,7 +2294,7 @@ class SubmitThread(threading.Thread):
 
         def __init__(
             self,
-            runner: 'Rm',
+            runner: Rm,
             args: argparse.Namespace,
             messages_queue: queue.Queue,
             reporter: ProgressReport,
@@ -2860,6 +2928,7 @@ def _setup_parser(cls, parser):
 
 
 class UploadFileMixin(
+    HeaderFlagsMixin,
     MinPartSizeMixin,
     ThreadsMixin,
     ProgressMixin,
@@ -2888,7 +2957,6 @@ def _setup_parser(cls, parser):
         parser.add_argument(
             '--sha1', help="SHA-1 of the data being uploaded for verifying file integrity"
         )
-        parser.add_argument('--cache-control', default=None)
         parser.add_argument(
             '--info',
             action='append',
@@ -2935,11 +3003,11 @@ def get_execute_kwargs(self, args) -> dict:
             else:
                 file_infos[SRC_LAST_MODIFIED_MILLIS] = str(int(mtime * 1000))
 
+        file_infos = self._file_info_with_header_args(args, file_infos)
+
         return {
             "bucket":
                 self.api.get_bucket_by_name(args.bucketName),
-            "cache_control":
-                args.cache_control,
             "content_type":
                 args.contentType,
             "custom_upload_timestamp":
@@ -2967,7 +3035,7 @@ def get_execute_kwargs(self, args) -> dict:
         }
 
     @abstractmethod
-    def execute_operation(self, **kwargs) -> 'b2sdk.file_version.FileVersion':
+    def execute_operation(self, **kwargs) -> b2sdk.file_version.FileVersion:
         raise NotImplementedError
 
     def upload_file_kwargs_to_unbound_upload(self, **kwargs):
@@ -2979,7 +3047,7 @@ def upload_file_kwargs_to_unbound_upload(self, **kwargs):
         kwargs["read_size"] = kwargs["min_part_size"] or DEFAULT_MIN_PART_SIZE
kwargs["min_part_size"] or DEFAULT_MIN_PART_SIZE return kwargs - def get_input_stream(self, filename: str) -> 'str | int | io.BinaryIO': + def get_input_stream(self, filename: str) -> str | int | io.BinaryIO: """Get input stream IF filename points to a FIFO or stdin.""" if filename == "-": if os.path.exists('-'): @@ -2993,9 +3061,7 @@ def get_input_stream(self, filename: str) -> 'str | int | io.BinaryIO': raise self.NotAnInputStream() - def file_identifier_to_read_stream( - self, file_id: 'str | int | BinaryIO', buffering - ) -> BinaryIO: + def file_identifier_to_read_stream(self, file_id: str | int | BinaryIO, buffering) -> BinaryIO: if isinstance(file_id, (str, int)): return open( file_id, @@ -3319,7 +3385,7 @@ def run(self, args): return 0 @classmethod - def alter_rule_by_name(cls, bucket: Bucket, name: str) -> Tuple[bool, bool]: + def alter_rule_by_name(cls, bucket: Bucket, name: str) -> tuple[bool, bool]: """ returns False if rule could not be found """ if not bucket.replication or not bucket.replication.rules: return False, False @@ -3356,7 +3422,7 @@ def alter_rule_by_name(cls, bucket: Bucket, name: str) -> Tuple[bool, bool]: @classmethod @abstractmethod - def alter_one_rule(cls, rule: ReplicationRule) -> Optional[ReplicationRule]: + def alter_one_rule(cls, rule: ReplicationRule) -> ReplicationRule | None: """ return None to delete a rule """ pass @@ -3373,7 +3439,7 @@ class ReplicationDelete(ReplicationRuleChanger): """ @classmethod - def alter_one_rule(cls, rule: ReplicationRule) -> Optional[ReplicationRule]: + def alter_one_rule(cls, rule: ReplicationRule) -> ReplicationRule | None: """ return None to delete rule """ return None @@ -3390,7 +3456,7 @@ class ReplicationPause(ReplicationRuleChanger): """ @classmethod - def alter_one_rule(cls, rule: ReplicationRule) -> Optional[ReplicationRule]: + def alter_one_rule(cls, rule: ReplicationRule) -> ReplicationRule | None: """ return None to delete rule """ rule.is_enabled = False return rule @@ -3408,7 +3474,7 @@ class ReplicationUnpause(ReplicationRuleChanger): """ @classmethod - def alter_one_rule(cls, rule: ReplicationRule) -> Optional[ReplicationRule]: + def alter_one_rule(cls, rule: ReplicationRule) -> ReplicationRule | None: """ return None to delete rule """ rule.is_enabled = True return rule @@ -3508,9 +3574,9 @@ def run(self, args): @classmethod def get_results_for_rule( - cls, bucket: Bucket, rule: ReplicationRule, destination_api: Optional[B2Api], + cls, bucket: Bucket, rule: ReplicationRule, destination_api: B2Api | None, scan_destination: bool, quiet: bool - ) -> List[dict]: + ) -> list[dict]: monitor = ReplicationMonitor( bucket=bucket, rule=rule, @@ -3527,7 +3593,7 @@ def get_results_for_rule( ] @classmethod - def filter_results_columns(cls, results: List[dict], columns: List[str]) -> List[dict]: + def filter_results_columns(cls, results: list[dict], columns: list[str]) -> list[dict]: return [{key: result[key] for key in columns} for result in results] @classmethod @@ -3543,10 +3609,10 @@ def to_human_readable(cls, value: Any) -> str: return str(value) - def output_json(self, results: Dict[str, List[dict]]) -> None: + def output_json(self, results: dict[str, list[dict]]) -> None: self._print_json(results) - def output_console(self, results: Dict[str, List[dict]]) -> None: + def output_console(self, results: dict[str, list[dict]]) -> None: for rule_name, rule_results in results.items(): self._print(f'Replication "{rule_name}":') rule_results = [ @@ -3558,7 +3624,7 @@ def output_console(self, results: Dict[str, 
             ]
             self._print(tabulate(rule_results, headers='keys', tablefmt='grid'))
 
-    def output_csv(self, results: Dict[str, List[dict]]) -> None:
+    def output_csv(self, results: dict[str, list[dict]]) -> None:
 
         rows = []
 
@@ -3735,7 +3801,7 @@ def _put_license_text_for_packages(self, stream: io.StringIO):
         stream.write(str(summary_table))
 
     @classmethod
-    def _get_licenses_dicts(cls) -> List[Dict]:
+    def _get_licenses_dicts(cls) -> list[dict]:
         assert piplicenses, 'In order to run this command, you need to install the `license` extra: pip install b2[license]'
         pipdeptree_run = subprocess.run(
             ["pipdeptree", "--json", "-p", "b2"],
@@ -3842,7 +3908,7 @@ class ConsoleTool:
     Uses a ``b2sdk.SqlitedAccountInfo`` object to keep account data between runs.
     """
 
-    def __init__(self, b2_api: Optional[B2Api], stdout, stderr):
+    def __init__(self, b2_api: B2Api | None, stdout, stderr):
         self.api = b2_api
         self.stdout = stdout
         self.stderr = stderr
diff --git a/changelog.d/+add_header_options.added.md b/changelog.d/+add_header_options.added.md
new file mode 100644
index 000000000..24b46d66f
--- /dev/null
+++ b/changelog.d/+add_header_options.added.md
@@ -0,0 +1 @@
+Add `--expires`, `--content-disposition`, `--content-encoding`, and `--content-language` options to the `upload-file`, `upload-unbound-stream`, and `copy-file-by-id` subcommands
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 68b52fc05..79f6c88b3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 argcomplete>=2,<4
 arrow>=1.0.2,<2.0.0
-b2sdk>=1.25.0,<2
+b2sdk>=1.26.0,<2
 docutils>=0.18.1
 idna~=3.4; platform_system == 'Java'
 importlib-metadata~=3.3; python_version < '3.8'
diff --git a/test/integration/test_b2_command_line.py b/test/integration/test_b2_command_line.py
index 2cb24c8f7..e62e359d3 100755
--- a/test/integration/test_b2_command_line.py
+++ b/test/integration/test_b2_command_line.py
@@ -2690,3 +2690,64 @@ def test_cat(b2_tool, bucket_name, sample_filepath, tmp_path, uploaded_sample_fi
     ).replace("\r", "") == sample_filepath.read_text()
     assert b2_tool.should_succeed(['cat', f"b2id://{uploaded_sample_file['fileId']}"
                                   ],).replace("\r", "") == sample_filepath.read_text()
+
+
+def test_header_arguments(b2_tool, bucket_name, sample_filepath, tmp_path):
+    # yapf: disable
+    args = [
+        '--cache-control', 'max-age=3600',
+        '--content-disposition', 'attachment',
+        '--content-encoding', 'gzip',
+        '--content-language', 'en',
+        '--expires', 'Thu, 01 Dec 2050 16:00:00 GMT',
+    ]
+    # yapf: enable
+    expected_file_info = {
+        'b2-cache-control': 'max-age=3600',
+        'b2-content-disposition': 'attachment',
+        'b2-content-encoding': 'gzip',
+        'b2-content-language': 'en',
+        'b2-expires': 'Thu, 01 Dec 2050 16:00:00 GMT',
+    }
+
+    def assert_expected(file_info, expected=expected_file_info):
+        for key, val in expected.items():
+            assert file_info[key] == val
+
+    status, stdout, stderr = b2_tool.execute(
+        [
+            'upload-file',
+            '--quiet',
+            '--noProgress',
+            bucket_name,
+            str(sample_filepath),
+            'sample_file',
+            *args,
+            '--info',
+            'b2-content-disposition=will-be-overwritten',
+        ]
+    )
+    assert status == 0
+    file_version = json.loads(stdout)
+    assert_expected(file_version['fileInfo'])
+
+    # Since we used both --info and --content-disposition to set b2-content-disposition,
+    # a warning should be emitted
+    assert 'will be overwritten' in stderr and 'b2-content-disposition = attachment' in stderr
+
+    copied_version = b2_tool.should_succeed_json(
+        [
+            'copy-file-by-id', '--quiet', *args, '--contentType', 'text/plain',
+            file_version['fileId'], bucket_name, 'copied_file'
+        ]
+    )
+    assert_expected(copied_version['fileInfo'])
+
+    download_output = b2_tool.should_succeed(
+        ['download-file-by-id', file_version['fileId'], tmp_path / 'downloaded_file']
+    )
+    assert re.search(r'CacheControl: *max-age=3600', download_output)
+    assert re.search(r'ContentDisposition: *attachment', download_output)
+    assert re.search(r'ContentEncoding: *gzip', download_output)
+    assert re.search(r'ContentLanguage: *en', download_output)
+    assert re.search(r'Expires: *Thu, 01 Dec 2050 16:00:00 GMT', download_output)
diff --git a/test/unit/console_tool/test_upload_file.py b/test/unit/console_tool/test_upload_file.py
index c9da4d322..095b1a6bc 100644
--- a/test/unit/console_tool/test_upload_file.py
+++ b/test/unit/console_tool/test_upload_file.py
@@ -13,7 +13,7 @@
 import b2
 
 
-def test_upload_file__file_info_src_last_modified_millis(b2_cli, bucket, tmpdir):
+def test_upload_file__file_info_src_last_modified_millis_and_headers(b2_cli, bucket, tmpdir):
     """Test upload_file supports manually specifying file info src_last_modified_millis"""
     filename = 'file1.txt'
     content = 'hello world'
@@ -23,15 +23,24 @@ def test_upload_file__file_info_src_last_modified_millis(b2_cli, bucket, tmpdir)
     expected_json = {
         "action": "upload",
         "contentSha1": "2aae6c35c94fcfb415dbe95f408b9ce91ee846ed",
-        "fileInfo": {
-            "src_last_modified_millis": "1"
-        },
+        "fileInfo":
+            {
+                "b2-cache-control": "max-age=3600",
+                "b2-expires": "Thu, 01 Dec 2050 16:00:00 GMT",
+                "b2-content-language": "en",
+                "b2-content-disposition": "attachment",
+                "b2-content-encoding": "gzip",
+                "src_last_modified_millis": "1"
+            },
         "fileName": filename,
         "size": len(content),
     }
     b2_cli.run(
         [
             'upload-file', '--noProgress', '--info=src_last_modified_millis=1', 'my-bucket',
+            '--cache-control', 'max-age=3600', '--expires', 'Thu, 01 Dec 2050 16:00:00 GMT',
+            '--content-language', 'en', '--content-disposition', 'attachment', '--content-encoding',
+            'gzip',
             str(local_file1), 'file1.txt'
         ],
         expected_json_in_stdout=expected_json,