Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CLI arguments to set Expire, Content-Language, Content-Disposition, Content-Encoding #961

Merged
merged 12 commits into from
Nov 22, 2023
Merged
138 changes: 102 additions & 36 deletions b2/console_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# License https://www.backblaze.com/using_b2_code.html
#
######################################################################
from __future__ import annotations

import argparse
import base64
Expand Down Expand Up @@ -37,7 +38,7 @@
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextlib import suppress
from enum import Enum
from typing import Any, BinaryIO, Dict, List, Optional, Tuple
from typing import Any, BinaryIO, List

import argcomplete
import b2sdk
Expand Down Expand Up @@ -393,6 +394,69 @@ def _get_file_retention_setting(cls, args):
return FileRetentionSetting(file_retention_mode, args.retainUntil)


class HeaderFlagsMixin(Described):
@classmethod
def _setup_parser(cls, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
'--cache-control',
help=
"optional Cache-Control header, value based on RFC 2616 section 14.9, example: 'public, max-age=86400')"
)
parser.add_argument(
'--content-disposition',
help=
"optional Content-Disposition header, value based on RFC 2616 section 19.5.1, example: 'attachment; filename=\"fname.ext\"'"
)
parser.add_argument(
'--content-encoding',
help=
"optional Content-Encoding header, value based on RFC 2616 section 14.11, example: 'gzip'"
)
parser.add_argument(
'--content-language',
help=
"optional Content-Language header, value based on RFC 2616 section 14.12, example: 'mi, en'"
)
parser.add_argument(
'--expires',
help=
"optional Expires header, value based on RFC 2616 section 14.21, example: 'Thu, 01 Dec 2050 16:00:00 GMT'"
)
super()._setup_parser(parser)

def _file_info_with_header_args(self, args,
file_info: dict[str, str] | None) -> dict[str, str] | None:
"""Construct an updated file_info dictionary.
Print a warning if any of file_info items will be overwritten by explicit header arguments.
"""
add_file_info = {}
overwritten = []
if args.cache_control is not None:
add_file_info['b2-cache-control'] = args.cache_control
if args.content_disposition is not None:
add_file_info['b2-content-disposition'] = args.content_disposition
if args.content_encoding is not None:
add_file_info['b2-content-encoding'] = args.content_encoding
if args.content_language is not None:
add_file_info['b2-content-language'] = args.content_language
if args.expires is not None:
add_file_info['b2-expires'] = args.expires

for key, value in add_file_info.items():
if file_info is not None and key in file_info and file_info[key] != value:
overwritten.append(key)

if overwritten:
self._print_stderr(
'The following file info items will be overwritten by explicit arguments:\n ' +
'\n '.join(f'{key} = {add_file_info[key]}' for key in overwritten)
)

if add_file_info:
return {**(file_info or {}), **add_file_info}
return file_info


class LegalHoldMixin(Described):
"""
Setting legal holds requires the **writeFileLegalHolds** capability, and only works in bucket
Expand Down Expand Up @@ -696,12 +760,12 @@ def _print_json(self, data) -> None:
json.dumps(data, indent=4, sort_keys=True, cls=B2CliJsonEncoder), enforce_output=True
)

def _print(self, *args, enforce_output: bool = False, end: Optional[str] = None) -> None:
def _print(self, *args, enforce_output: bool = False, end: str | None = None) -> None:
return self._print_standard_descriptor(
self.stdout, "stdout", *args, enforce_output=enforce_output, end=end
)

def _print_stderr(self, *args, end: Optional[str] = None) -> None:
def _print_stderr(self, *args, end: str | None = None) -> None:
return self._print_standard_descriptor(
self.stderr, "stderr", *args, enforce_output=True, end=end
)
Expand All @@ -712,7 +776,7 @@ def _print_standard_descriptor(
descriptor_name: str,
*args,
enforce_output: bool = False,
end: Optional[str] = None,
end: str | None = None,
) -> None:
"""
Prints to fd, unless quiet is set.
Expand All @@ -733,7 +797,7 @@ def _print_helper(
descriptor_encoding: str,
descriptor_name: str,
*args,
end: Optional[str] = None
end: str | None = None
):
try:
descriptor.write(' '.join(args))
Expand Down Expand Up @@ -1010,7 +1074,8 @@ def run(self, args):

@B2.register_subcommand
class CopyFileById(
DestinationSseMixin, SourceSseMixin, FileRetentionSettingMixin, LegalHoldMixin, Command
HeaderFlagsMixin, DestinationSseMixin, SourceSseMixin, FileRetentionSettingMixin,
LegalHoldMixin, Command
):
"""
Copy a file version to the given bucket (server-side, **not** via download+upload).
Expand Down Expand Up @@ -1072,6 +1137,7 @@ def run(self, args):
file_infos = self._parse_file_infos(args.info)
elif args.noInfo:
file_infos = {}
file_infos = self._file_info_with_header_args(args, file_infos)

if args.metadataDirective is not None:
self._print_stderr(
Expand Down Expand Up @@ -1115,20 +1181,20 @@ def run(self, args):
self._print_json(file_version)
return 0

def _is_ssec(self, encryption: Optional[EncryptionSetting]):
def _is_ssec(self, encryption: EncryptionSetting | None):
if encryption is not None and encryption.mode == EncryptionMode.SSE_C:
return True
return False

def _determine_source_metadata(
self,
source_file_id: str,
destination_encryption: Optional[EncryptionSetting],
source_encryption: Optional[EncryptionSetting],
target_file_info: Optional[dict],
target_content_type: Optional[str],
destination_encryption: EncryptionSetting | None,
source_encryption: EncryptionSetting | None,
target_file_info: dict | None,
target_content_type: str | None,
fetch_if_necessary: bool,
) -> Tuple[Optional[dict], Optional[str]]:
) -> tuple[dict | None, str | None]:
"""Determine if source file metadata is necessary to perform the copy - due to sse_c_key_id"""
if not self._is_ssec(source_encryption) and not self._is_ssec(
destination_encryption
Expand Down Expand Up @@ -1361,6 +1427,8 @@ def _print_download_info(self, downloaded_file: DownloadedFile):
'Legal hold', self._represent_legal_hold(download_version.legal_hold)
)
for label, attr_name in [
('CacheControl', 'cache_control'),
('Expires', 'expires'),
('ContentDisposition', 'content_disposition'),
('ContentLanguage', 'content_language'),
('ContentEncoding', 'content_encoding'),
Expand Down Expand Up @@ -1997,7 +2065,7 @@ def _print_file_version(
self,
args,
file_version: FileVersion,
folder_name: Optional[str],
folder_name: str | None,
) -> None:
self._print(folder_name or file_version.file_name)

Expand Down Expand Up @@ -2104,7 +2172,7 @@ def _print_file_version(
self,
args,
file_version: FileVersion,
folder_name: Optional[str],
folder_name: str | None,
) -> None:
if not args.long:
super()._print_file_version(args, file_version, folder_name)
Expand Down Expand Up @@ -2226,7 +2294,7 @@ class SubmitThread(threading.Thread):

def __init__(
self,
runner: 'Rm',
runner: Rm,
args: argparse.Namespace,
messages_queue: queue.Queue,
reporter: ProgressReport,
Expand Down Expand Up @@ -2860,6 +2928,7 @@ def _setup_parser(cls, parser):


class UploadFileMixin(
HeaderFlagsMixin,
MinPartSizeMixin,
ThreadsMixin,
ProgressMixin,
Expand Down Expand Up @@ -2888,7 +2957,6 @@ def _setup_parser(cls, parser):
parser.add_argument(
'--sha1', help="SHA-1 of the data being uploaded for verifying file integrity"
)
parser.add_argument('--cache-control', default=None)
parser.add_argument(
'--info',
action='append',
Expand Down Expand Up @@ -2935,11 +3003,11 @@ def get_execute_kwargs(self, args) -> dict:
else:
file_infos[SRC_LAST_MODIFIED_MILLIS] = str(int(mtime * 1000))

file_infos = self._file_info_with_header_args(args, file_infos)

return {
"bucket":
self.api.get_bucket_by_name(args.bucketName),
"cache_control":
args.cache_control,
"content_type":
args.contentType,
"custom_upload_timestamp":
Expand Down Expand Up @@ -2967,7 +3035,7 @@ def get_execute_kwargs(self, args) -> dict:
}

@abstractmethod
def execute_operation(self, **kwargs) -> 'b2sdk.file_version.FileVersion':
def execute_operation(self, **kwargs) -> b2sdk.file_version.FileVersion:
raise NotImplementedError

def upload_file_kwargs_to_unbound_upload(self, **kwargs):
Expand All @@ -2979,7 +3047,7 @@ def upload_file_kwargs_to_unbound_upload(self, **kwargs):
kwargs["read_size"] = kwargs["min_part_size"] or DEFAULT_MIN_PART_SIZE
return kwargs

def get_input_stream(self, filename: str) -> 'str | int | io.BinaryIO':
def get_input_stream(self, filename: str) -> str | int | io.BinaryIO:
"""Get input stream IF filename points to a FIFO or stdin."""
if filename == "-":
if os.path.exists('-'):
Expand All @@ -2993,9 +3061,7 @@ def get_input_stream(self, filename: str) -> 'str | int | io.BinaryIO':

raise self.NotAnInputStream()

def file_identifier_to_read_stream(
self, file_id: 'str | int | BinaryIO', buffering
) -> BinaryIO:
def file_identifier_to_read_stream(self, file_id: str | int | BinaryIO, buffering) -> BinaryIO:
if isinstance(file_id, (str, int)):
return open(
file_id,
Expand Down Expand Up @@ -3319,7 +3385,7 @@ def run(self, args):
return 0

@classmethod
def alter_rule_by_name(cls, bucket: Bucket, name: str) -> Tuple[bool, bool]:
def alter_rule_by_name(cls, bucket: Bucket, name: str) -> tuple[bool, bool]:
""" returns False if rule could not be found """
if not bucket.replication or not bucket.replication.rules:
return False, False
Expand Down Expand Up @@ -3356,7 +3422,7 @@ def alter_rule_by_name(cls, bucket: Bucket, name: str) -> Tuple[bool, bool]:

@classmethod
@abstractmethod
def alter_one_rule(cls, rule: ReplicationRule) -> Optional[ReplicationRule]:
def alter_one_rule(cls, rule: ReplicationRule) -> ReplicationRule | None:
""" return None to delete a rule """
pass

Expand All @@ -3373,7 +3439,7 @@ class ReplicationDelete(ReplicationRuleChanger):
"""

@classmethod
def alter_one_rule(cls, rule: ReplicationRule) -> Optional[ReplicationRule]:
def alter_one_rule(cls, rule: ReplicationRule) -> ReplicationRule | None:
""" return None to delete rule """
return None

Expand All @@ -3390,7 +3456,7 @@ class ReplicationPause(ReplicationRuleChanger):
"""

@classmethod
def alter_one_rule(cls, rule: ReplicationRule) -> Optional[ReplicationRule]:
def alter_one_rule(cls, rule: ReplicationRule) -> ReplicationRule | None:
""" return None to delete rule """
rule.is_enabled = False
return rule
Expand All @@ -3408,7 +3474,7 @@ class ReplicationUnpause(ReplicationRuleChanger):
"""

@classmethod
def alter_one_rule(cls, rule: ReplicationRule) -> Optional[ReplicationRule]:
def alter_one_rule(cls, rule: ReplicationRule) -> ReplicationRule | None:
""" return None to delete rule """
rule.is_enabled = True
return rule
Expand Down Expand Up @@ -3508,9 +3574,9 @@ def run(self, args):

@classmethod
def get_results_for_rule(
cls, bucket: Bucket, rule: ReplicationRule, destination_api: Optional[B2Api],
cls, bucket: Bucket, rule: ReplicationRule, destination_api: B2Api | None,
scan_destination: bool, quiet: bool
) -> List[dict]:
) -> list[dict]:
monitor = ReplicationMonitor(
bucket=bucket,
rule=rule,
Expand All @@ -3527,7 +3593,7 @@ def get_results_for_rule(
]

@classmethod
def filter_results_columns(cls, results: List[dict], columns: List[str]) -> List[dict]:
def filter_results_columns(cls, results: list[dict], columns: list[str]) -> list[dict]:
return [{key: result[key] for key in columns} for result in results]

@classmethod
Expand All @@ -3543,10 +3609,10 @@ def to_human_readable(cls, value: Any) -> str:

return str(value)

def output_json(self, results: Dict[str, List[dict]]) -> None:
def output_json(self, results: dict[str, list[dict]]) -> None:
self._print_json(results)

def output_console(self, results: Dict[str, List[dict]]) -> None:
def output_console(self, results: dict[str, list[dict]]) -> None:
for rule_name, rule_results in results.items():
self._print(f'Replication "{rule_name}":')
rule_results = [
Expand All @@ -3558,7 +3624,7 @@ def output_console(self, results: Dict[str, List[dict]]) -> None:
]
self._print(tabulate(rule_results, headers='keys', tablefmt='grid'))

def output_csv(self, results: Dict[str, List[dict]]) -> None:
def output_csv(self, results: dict[str, list[dict]]) -> None:

rows = []

Expand Down Expand Up @@ -3735,7 +3801,7 @@ def _put_license_text_for_packages(self, stream: io.StringIO):
stream.write(str(summary_table))

@classmethod
def _get_licenses_dicts(cls) -> List[Dict]:
def _get_licenses_dicts(cls) -> list[dict]:
assert piplicenses, 'In order to run this command, you need to install the `license` extra: pip install b2[license]'
pipdeptree_run = subprocess.run(
["pipdeptree", "--json", "-p", "b2"],
Expand Down Expand Up @@ -3842,7 +3908,7 @@ class ConsoleTool:
Uses a ``b2sdk.SqlitedAccountInfo`` object to keep account data between runs.
"""

def __init__(self, b2_api: Optional[B2Api], stdout, stderr):
def __init__(self, b2_api: B2Api | None, stdout, stderr):
self.api = b2_api
self.stdout = stdout
self.stderr = stderr
Expand Down
1 change: 1 addition & 0 deletions changelog.d/+add_header_options.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `--expires`, `--content-disposition`, `--content-encoding`, `--content-language` options to subcommands `upload-file`, `upload-unbound-stream`, `copy-file-by-id`
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
argcomplete>=2,<4
arrow>=1.0.2,<2.0.0
b2sdk>=1.25.0,<2
b2sdk>=1.26.0,<2
docutils>=0.18.1
idna~=3.4; platform_system == 'Java'
importlib-metadata~=3.3; python_version < '3.8'
Expand Down
Loading
Loading