Merge pull request #869 from lsst/tickets/DM-40156
DM-40156: Some code cleanups
timj authored Jul 24, 2023
2 parents ee6d72d + 05bb4df commit 0787220
Showing 76 changed files with 427 additions and 407 deletions.
52 changes: 23 additions & 29 deletions python/lsst/daf/butler/_butler.py
@@ -224,10 +224,7 @@ def __init__(
         else:
             self._config = ButlerConfig(config, searchPaths=searchPaths, without_datastore=without_datastore)
         try:
-            if "root" in self._config:
-                butlerRoot = self._config["root"]
-            else:
-                butlerRoot = self._config.configDir
+            butlerRoot = self._config.get("root", self._config.configDir)
             if writeable is None:
                 writeable = run is not None
             self._registry = _RegistryFactory(self._config).from_config(
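The cleanup above relies on `ButlerConfig` exposing `Mapping`-style access, so `get` can supply the fallback directly. A minimal before/after sketch, using a plain dict as a stand-in for the config object:

# Plain dict standing in for ButlerConfig; "root" is absent.
config = {"datastore": "posix"}
fallback_dir = "/repo/config"  # stand-in for self._config.configDir

# Before: explicit membership test.
if "root" in config:
    butler_root = config["root"]
else:
    butler_root = fallback_dir

# After: Mapping.get returns the default when the key is missing.
# (The default expression is always evaluated, which is harmless here.)
butler_root = config.get("root", fallback_dir)
assert butler_root == fallback_dir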
@@ -404,7 +401,7 @@ def makeRepo(
         construct the repository should also be used to construct any Butlers
         to avoid configuration inconsistencies.
         """
-        if isinstance(config, (ButlerConfig, ConfigSubset)):
+        if isinstance(config, ButlerConfig | ConfigSubset):
             raise ValueError("makeRepo must be passed a regular Config without defaults applied.")
 
         # Ensure that the root of the repository exists or can be made
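This hunk, like many below, swaps the tuple form of `isinstance` for a PEP 604 union, which `isinstance` accepts on Python 3.10 and newer. A self-contained illustration (the stub classes are ours, not the repo's):

# isinstance() accepts X | Y unions directly on Python 3.10+ (PEP 604).
value = 42
assert isinstance(value, (int, float)) == isinstance(value, int | float)

# It works for user-defined classes too:
class ButlerConfigStub: ...
class ConfigSubsetStub: ...

assert isinstance(ButlerConfigStub(), ButlerConfigStub | ConfigSubsetStub)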
@@ -552,9 +549,8 @@ def transaction(self) -> Iterator[None]:
         Transactions can be nested.
         """
-        with self._registry.transaction():
-            with self._datastore.transaction():
-                yield
+        with self._registry.transaction(), self._datastore.transaction():
+            yield
 
     def _standardizeArgs(
         self,
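Flattening nested `with` blocks into one statement is behavior-preserving: managers are entered left to right and exited in reverse order. A runnable sketch (the `tracked` helper is ours):

from contextlib import contextmanager

@contextmanager
def tracked(name: str):
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

# Both forms print: enter registry, enter datastore, exit datastore, exit registry.
with tracked("registry"):
    with tracked("datastore"):
        pass

with tracked("registry"), tracked("datastore"):
    pass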
@@ -1187,7 +1183,7 @@ def put(
             try:
                 self._datastore.put(obj, datasetRefOrType)
             except IntegrityError as e:
-                raise ConflictingDefinitionError(f"Datastore already contains dataset: {e}")
+                raise ConflictingDefinitionError(f"Datastore already contains dataset: {e}") from e
             return datasetRefOrType
 
         log.debug("Butler put: %s, dataId=%s, run=%s", datasetRefOrType, dataId, run)
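Adding `from e` chains the new exception to the one it translates, so tracebacks read "the direct cause of" rather than "during handling of the above exception, another exception occurred". A sketch with a stand-in for sqlalchemy's `IntegrityError`:

class ConflictingDefinitionError(Exception):
    pass

def put_dataset():
    try:
        raise KeyError("dataset already present")  # stand-in for IntegrityError
    except KeyError as e:
        raise ConflictingDefinitionError(f"Datastore already contains dataset: {e}") from e

try:
    put_dataset()
except ConflictingDefinitionError as err:
    assert isinstance(err.__cause__, KeyError)  # original error preserved on __cause__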
@@ -1736,7 +1732,7 @@ def _exists_many(
                 existence[ref] |= DatasetExistence._ARTIFACT
         else:
             # Do not set this flag if nothing is known about the dataset.
-            for ref in existence.keys():
+            for ref in existence:
                 if existence[ref] != DatasetExistence.UNRECOGNIZED:
                     existence[ref] |= DatasetExistence._ASSUMED
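Iterating a dict yields its keys, so the `.keys()` call was redundant. The `|=` updates work because `DatasetExistence` combines flag members; a simplified stand-in of our own:

from enum import Flag, auto

class Existence(Flag):  # simplified stand-in for DatasetExistence
    UNRECOGNIZED = 0
    RECORDED = auto()
    ASSUMED = auto()

existence = {"ref1": Existence.RECORDED, "ref2": Existence.UNRECOGNIZED}
for ref in existence:  # same keys as existence.keys()
    if existence[ref] != Existence.UNRECOGNIZED:
        existence[ref] |= Existence.ASSUMED

assert existence["ref1"] == Existence.RECORDED | Existence.ASSUMED
assert existence["ref2"] == Existence.UNRECOGNIZED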

@@ -1827,14 +1823,13 @@ def removeRuns(self, names: Iterable[str], unstore: bool = True) -> None:
             if collectionType is not CollectionType.RUN:
                 raise TypeError(f"The collection type of '{name}' is {collectionType.name}, not RUN.")
             refs.extend(self._registry.queryDatasets(..., collections=name, findFirst=True))
-        with self._datastore.transaction():
-            with self._registry.transaction():
-                if unstore:
-                    self._datastore.trash(refs)
-                else:
-                    self._datastore.forget(refs)
-                for name in names:
-                    self._registry.removeCollection(name)
+        with self._datastore.transaction(), self._registry.transaction():
+            if unstore:
+                self._datastore.trash(refs)
+            else:
+                self._datastore.forget(refs)
+            for name in names:
+                self._registry.removeCollection(name)
         if unstore:
             # Point of no return for removing artifacts
             self._datastore.emptyTrash()
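The "point of no return" comment reflects a two-phase delete: reversible bookkeeping (`trash`) happens inside the transaction, while irreversible artifact removal (`emptyTrash`) runs only after a successful commit. A toy sketch of that split, with hypothetical helpers of our own rather than the Butler API:

from contextlib import contextmanager

trashed: list[str] = []

@contextmanager
def transaction():
    mark = len(trashed)
    try:
        yield
    except Exception:
        del trashed[mark:]  # roll back the reversible bookkeeping
        raise

def empty_trash() -> None:
    while trashed:
        print("deleting artifact", trashed.pop())  # irreversible

with transaction():
    trashed.extend(["ref-a", "ref-b"])  # stand-in for datastore.trash(refs)
empty_trash()  # crossed only after the transaction commits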
@@ -1882,16 +1877,15 @@ def pruneDatasets(
         # mutating the Registry (it can _look_ at Datastore-specific things,
         # but shouldn't change them), and hence all operations here are
         # Registry operations.
-        with self._datastore.transaction():
-            with self._registry.transaction():
-                if unstore:
-                    self._datastore.trash(refs)
-                if purge:
-                    self._registry.removeDatasets(refs)
-                elif disassociate:
-                    assert tags, "Guaranteed by earlier logic in this function."
-                    for tag in tags:
-                        self._registry.disassociate(tag, refs)
+        with self._datastore.transaction(), self._registry.transaction():
+            if unstore:
+                self._datastore.trash(refs)
+            if purge:
+                self._registry.removeDatasets(refs)
+            elif disassociate:
+                assert tags, "Guaranteed by earlier logic in this function."
+                for tag in tags:
+                    self._registry.disassociate(tag, refs)
         # We've exited the Registry transaction, and apparently committed.
         # (if there was an exception, everything rolled back, and it's as if
         # nothing happened - and we never get here).
@@ -2079,7 +2073,7 @@ def ingest(
                 *datasets, transfer=transfer, record_validation_info=record_validation_info
             )
         except IntegrityError as e:
-            raise ConflictingDefinitionError(f"Datastore already contains one or more datasets: {e}")
+            raise ConflictingDefinitionError(f"Datastore already contains one or more datasets: {e}") from e
 
     @contextlib.contextmanager
     def export(
7 changes: 3 additions & 4 deletions python/lsst/daf/butler/_butlerConfig.py
@@ -25,6 +25,7 @@
 
 __all__ = ("ButlerConfig",)
 
+import contextlib
 import copy
 import os
 from collections.abc import Sequence
@@ -86,19 +87,17 @@ def __init__(
         original_other = other
         resolved_alias = False
         if isinstance(other, str):
-            try:
+            with contextlib.suppress(Exception):
                 # Force back to a string because the resolved URI
                 # might not refer explicitly to a directory and we have
                 # check below to guess that.
                 other = str(ButlerRepoIndex.get_repo_uri(other, True))
-            except Exception:
-                pass
             if other != original_other:
                 resolved_alias = True
 
         # Include ResourcePath here in case it refers to a directory.
         # Creating a ResourcePath from a ResourcePath is a no-op.
-        if isinstance(other, (str, os.PathLike, ResourcePath)):
+        if isinstance(other, str | os.PathLike | ResourcePath):
             # This will only allow supported schemes
             uri = ResourcePath(other)
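`contextlib.suppress` replaces a `try`/`except`/`pass` construct: if a listed exception is raised inside the block, execution simply resumes after it. A sketch with a hypothetical resolver standing in for `ButlerRepoIndex.get_repo_uri`:

import contextlib

def resolve_alias(name: str) -> str:
    raise LookupError(name)  # hypothetical stand-in that always fails

other = "my-repo-alias"
with contextlib.suppress(Exception):
    other = resolve_alias(other)  # on failure, `other` is left unchanged

assert other == "my-repo-alias"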
9 changes: 3 additions & 6 deletions python/lsst/daf/butler/_quantum_backed.py
@@ -337,10 +337,7 @@ def _initialize(
             Mapping of the dataset type name to its registry definition.
         """
         butler_config = ButlerConfig(config, searchPaths=search_paths)
-        if "root" in butler_config:
-            butler_root = butler_config["root"]
-        else:
-            butler_root = butler_config.configDir
+        butler_root = butler_config.get("root", butler_config.configDir)
         db = SqliteDatabase.fromUri(f"sqlite:///{filename}", origin=0)
         with db.declareStaticTables(create=True) as context:
             opaque_manager = OpaqueManagerClass.initialize(db, context)
@@ -574,7 +571,7 @@ def extract_provenance_data(self) -> QuantumProvenanceData:
         )
         self._actual_inputs -= self._unavailable_inputs
         checked_inputs = self._available_inputs | self._unavailable_inputs
-        if not self._predicted_inputs == checked_inputs:
+        if self._predicted_inputs != checked_inputs:
             _LOG.warning(
                 "Execution harness did not check predicted inputs %s for existence; available inputs "
                 "recorded in provenance may be incomplete.",
@@ -695,7 +692,7 @@ def collect_and_transfer(
         """
         grouped_refs = defaultdict(list)
         summary_records: dict[str, DatastoreRecordData] = {}
-        for quantum, provenance_for_quantum in zip(quanta, provenance):
+        for quantum, provenance_for_quantum in zip(quanta, provenance, strict=True):
             quantum_refs_by_id = {
                 ref.id: ref
                 for ref in itertools.chain.from_iterable(quantum.outputs.values())
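`zip(..., strict=True)` (Python 3.10+) turns a silent truncation into an error when the iterables disagree in length, which is exactly the invariant wanted between quanta and their provenance records. A quick demonstration:

quanta = ["q1", "q2", "q3"]
provenance = ["p1", "p2"]  # one entry short

# Default zip silently drops the unmatched tail:
assert list(zip(quanta, provenance)) == [("q1", "p1"), ("q2", "p2")]

# strict=True surfaces the mismatch instead:
try:
    list(zip(quanta, provenance, strict=True))
except ValueError as e:
    print(e)  # "zip() argument 2 is shorter than argument 1"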
3 changes: 1 addition & 2 deletions python/lsst/daf/butler/cli/cmd/commands.py
@@ -22,7 +22,6 @@
 
 __all__ = ()
 
-import warnings
 from typing import Any
 
 import click
@@ -115,7 +114,7 @@ def butler_import(*args: Any, **kwargs: Any) -> None:
     # `reuse_ids`` is not used by `butlerImport`.
     reuse_ids = kwargs.pop("reuse_ids", False)
     if reuse_ids:
-        warnings.warn("--reuse-ids option is deprecated and will be removed after v26.", FutureWarning)
+        click.echo("WARNING: --reuse-ids option is deprecated and will be removed after v26.")
 
     script.butlerImport(*args, **kwargs)
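For a command-line tool, `click.echo` prints the deprecation notice straight to the terminal (and is captured cleanly by click's test runner), whereas `warnings.warn` can be filtered or deduplicated before a user ever sees it. A minimal sketch with a hypothetical command of our own:

import click

@click.command()
@click.option("--reuse-ids", is_flag=True, default=False)
def demo_import(reuse_ids: bool) -> None:
    """Hypothetical command illustrating the deprecation message."""
    if reuse_ids:
        click.echo("WARNING: --reuse-ids option is deprecated and will be removed after v26.")

# click.testing.CliRunner().invoke(demo_import, ["--reuse-ids"]).output
# -> "WARNING: --reuse-ids option is deprecated and will be removed after v26.\n"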
5 changes: 3 additions & 2 deletions python/lsst/daf/butler/cli/progress.py
@@ -24,7 +24,8 @@
 __all__ = ("ClickProgressHandler",)
 
 from collections.abc import Iterable
-from typing import Any, ContextManager, TypeVar
+from contextlib import AbstractContextManager
+from typing import Any, TypeVar
 
 import click
 
@@ -75,6 +76,6 @@ def option(cls, cmd: Any) -> Any:
 
     def get_progress_bar(
         self, iterable: Iterable[_T] | None, desc: str | None, total: int | None, level: int
-    ) -> ContextManager[ProgressBar[_T]]:
+    ) -> AbstractContextManager[ProgressBar[_T]]:
         # Docstring inherited.
         return click.progressbar(iterable=iterable, length=total, label=desc, **self._kwargs)  # type: ignore
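`typing.ContextManager` is a deprecated alias of `contextlib.AbstractContextManager`, which has accepted subscripts in annotations since Python 3.9, so the runtime class can be used directly. A small sketch:

from contextlib import AbstractContextManager, nullcontext

def get_resource(value: int) -> AbstractContextManager[int]:
    # nullcontext is a trivial context manager that yields its argument.
    return nullcontext(value)

with get_resource(42) as v:
    assert v == 42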
17 changes: 8 additions & 9 deletions python/lsst/daf/butler/cli/utils.py
@@ -301,12 +301,11 @@ def split_commas(
                     stacklevel=2,
                 )
                 in_parens = False
-            elif c == ",":
-                if not in_parens:
-                    # Split on this comma.
-                    valueList.append(current)
-                    current = ""
-                    continue
+            elif c == "," and not in_parens:
+                # Split on this comma.
+                valueList.append(current)
+                current = ""
+                continue
             current += c
         if in_parens:
             warnings.warn(
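Merging the nested `if` into `elif c == "," and not in_parens` keeps the fall-through behavior: a comma inside parentheses skips the branch entirely and is appended like any other character. A condensed, self-contained version of the loop (our simplification, without the repo's warnings):

def split_commas(value: str) -> list[str]:
    parts: list[str] = []
    current, in_parens = "", False
    for c in value:
        if c == "(":
            in_parens = True
        elif c == ")":
            in_parens = False
        elif c == "," and not in_parens:
            parts.append(current)  # split on this comma
            current = ""
            continue
        current += c
    parts.append(current)
    return parts

assert split_commas("a,b(1,2),c") == ["a", "b(1,2)", "c"]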
@@ -479,7 +478,7 @@ def get(self) -> tuple[tuple[str, str], ...]:
                 raise click.ClickException(
                     message=f"Could not parse key-value pair '{val}' using separator '{separator}', "
                     f"with multiple values {'allowed' if multiple else 'not allowed'}: {e}"
-                )
+                ) from None
             ret.add(k, norm(v))
         return ret.get()
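Where `from e` makes the chain explicit, `raise ... from None` does the opposite: it suppresses the implicit context so the user sees only the friendly error. A sketch with a stand-in for `click.ClickException`:

class CLIError(Exception):  # stand-in for click.ClickException
    pass

def parse_pair(val: str) -> tuple[str, str]:
    try:
        k, v = val.split("=")
    except ValueError:
        raise CLIError(f"Could not parse key-value pair '{val}'") from None
    return k, v

try:
    parse_pair("no-separator")
except CLIError as e:
    assert e.__suppress_context__  # original ValueError hidden from the traceback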

@@ -844,7 +843,7 @@ def _capture_args(self, ctx: click.Context, args: list[str]) -> None:
             if (opt := opts[param_name]) is not None:
                 captured_args.append(opt)
             else:
-                assert False  # All parameters should be an Option or an Argument.
+                raise AssertionError("All parameters should be an Option or an Argument")
         MWCtxObj.getFrom(ctx).args = captured_args
 
     def parse_args(self, ctx: click.Context, args: Any) -> list[str]:
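`assert` statements are removed entirely when Python runs with `-O`, so `assert False` is not a dependable guard; raising `AssertionError` explicitly fires unconditionally. Illustration:

def capture(opt: str | None) -> str:
    if opt is not None:
        return opt
    # `assert False` would be stripped under `python -O`; this raise never is.
    raise AssertionError("All parameters should be an Option or an Argument")

try:
    capture(None)
except AssertionError as e:
    print(e)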
@@ -1030,7 +1029,7 @@ def _name_for_option(ctx: click.Context, option: str) -> str:
                 option_name=param,
                 message=f"Error reading overrides file: {e}",
                 ctx=ctx,
-            )
+            ) from None
         # Override the defaults for this subcommand
         ctx.default_map.update(overrides)
         return
16 changes: 7 additions & 9 deletions python/lsst/daf/butler/core/config.py
@@ -266,11 +266,11 @@ def __init__(self, other: ResourcePathExpression | Config | Mapping[str, Any] |
             # fail. Safer to use update().
             self.update(other._data)
             self.configFile = other.configFile
-        elif isinstance(other, (dict, Mapping)):
+        elif isinstance(other, dict | Mapping):
             # In most cases we have a dict, and it's more efficient
             # to check for a dict instance before checking the generic mapping.
             self.update(other)
-        elif isinstance(other, (str, ResourcePath, Path)):
+        elif isinstance(other, str | ResourcePath | Path):
             # if other is a string, assume it is a file path/URI
             self.__initFromUri(other)
         self._processExplicitIncludes()
@@ -417,7 +417,7 @@ def __initFromJson(self, stream: IO | str | bytes) -> Config:
         TypeError:
             Raised if there is an error loading the content.
         """
-        if isinstance(stream, (bytes, str)):
+        if isinstance(stream, bytes | str):
             content = json.loads(stream)
         else:
             content = json.load(stream)

# In most cases we have a dict, and it's more efficient
# to check for a dict instance before checking the generic mapping.
if isinstance(data, (dict, Mapping)):
if isinstance(data, dict | Mapping):
data = Config(data)
# Ensure that child configs inherit the parent internal delimiter
if self._D != Config._D:
@@ -763,7 +763,7 @@ def getKeysAsTuples(
                 val = d[key]
                 levelKey = base + (key,) if base is not None else (key,)
                 keys.append(levelKey)
-                if isinstance(val, (Mapping, Sequence)) and not isinstance(val, str):
+                if isinstance(val, Mapping | Sequence) and not isinstance(val, str):
                     getKeysAsTuples(val, keys, levelKey)
 
         keys: list[tuple[str, ...]] = []
@@ -863,9 +863,7 @@ def asArray(self, name: str | Sequence[str]) -> Sequence[Any]:
             will be returned, else the value will be the first element.
         """
         val = self.get(name)
-        if isinstance(val, str):
-            val = [val]
-        elif not isinstance(val, Sequence):
+        if isinstance(val, str) or not isinstance(val, Sequence):
             val = [val]
         return val
 
@@ -1301,7 +1299,7 @@ def _updateWithConfigsFromPath(
             # Reverse order so that high priority entries
             # update the object last.
             for pathDir in reversed(searchPaths):
-                if isinstance(pathDir, (str, ResourcePath)):
+                if isinstance(pathDir, str | ResourcePath):
                     pathDir = ResourcePath(pathDir, forceDirectory=True)
                 file = pathDir.join(configFile)
                 if file.exists():
4 changes: 2 additions & 2 deletions python/lsst/daf/butler/core/datasets/type.py
@@ -197,7 +197,7 @@ def __init__(
         self._dimensions = dimensions
         if name in self._dimensions.universe.getGovernorDimensions().names:
             raise ValueError(f"Governor dimension name {name} cannot be used as a dataset type name.")
-        if not isinstance(storageClass, (StorageClass, str)):
+        if not isinstance(storageClass, StorageClass | str):
             raise ValueError(f"StorageClass argument must be StorageClass or str. Got {storageClass}")
         self._storageClass: StorageClass | None
         if isinstance(storageClass, StorageClass):
@@ -210,7 +210,7 @@ def __init__(
         self._parentStorageClass: StorageClass | None = None
         self._parentStorageClassName: str | None = None
         if parentStorageClass is not None:
-            if not isinstance(storageClass, (StorageClass, str)):
+            if not isinstance(storageClass, StorageClass | str):
                 raise ValueError(
                     f"Parent StorageClass argument must be StorageClass or str. Got {parentStorageClass}"
                 )
6 changes: 2 additions & 4 deletions python/lsst/daf/butler/core/datastoreCacheManager.py
@@ -861,7 +861,7 @@ def known_to_cache(self, ref: DatasetRef, extension: str | None = None) -> bool:
             # Look solely for matching dataset ref ID and not specific
             # components.
             cached_paths = self._cache_entries.get_dataset_keys(ref.id)
-            return True if cached_paths else False
+            return bool(cached_paths)
 
         else:
             # Extension is known so we can do an explicit look up for the
@@ -885,10 +885,8 @@ def _remove_from_cache(self, cache_entries: Iterable[str]) -> None:
 
             self._cache_entries.pop(entry, None)
             log.debug("Removing file from cache: %s", path)
-            try:
+            with contextlib.suppress(FileNotFoundError):
                 path.remove()
-            except FileNotFoundError:
-                pass
 
     def _expire_cache(self) -> None:
         """Expire the files in the cache.
17 changes: 7 additions & 10 deletions python/lsst/daf/butler/core/ddl.py
@@ -99,7 +99,7 @@ def decorated(self: Any, config: Config, *args: Any, **kwargs: Any) -> Any:
             try:
                 return func(self, config, *args, **kwargs)
             except caught as err:
-                raise cls(message.format(config=str(config), err=err))
+                raise cls(message.format(config=str(config), err=err)) from err
 
         return decorated
 
@@ -118,7 +118,7 @@ class Base64Bytes(sqlalchemy.TypeDecorator):
 
     def __init__(self, nbytes: int | None = None, *args: Any, **kwargs: Any):
         if nbytes is not None:
-            length = 4 * ceil(nbytes / 3) if self.impl == sqlalchemy.String else None
+            length = 4 * ceil(nbytes / 3) if self.impl is sqlalchemy.String else None
         else:
             length = None
         super().__init__(*args, length=length, **kwargs)
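The `4 * ceil(nbytes / 3)` sizing follows from base64 itself: every 3-byte group becomes 4 output characters, with `=` padding completing the final group. A quick check:

import base64
from math import ceil

for nbytes in (1, 2, 3, 16, 100):
    encoded = base64.b64encode(b"\x00" * nbytes)
    assert len(encoded) == 4 * ceil(nbytes / 3)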
@@ -249,9 +249,7 @@ def process_bind_param(self, value: Any, dialect: sqlalchemy.Dialect) -> str | N
     def process_result_value(
         self, value: str | uuid.UUID | None, dialect: sqlalchemy.Dialect
     ) -> uuid.UUID | None:
-        if value is None:
-            return value
-        elif isinstance(value, uuid.UUID):
+        if value is None or isinstance(value, uuid.UUID):
             # sqlalchemy 2 converts to UUID internally
             return value
         else:
@@ -401,10 +399,9 @@ def isStringType(self) -> bool:
         string type if it has been decided that it should be implemented
         as a `sqlalchemy.Text` type.
         """
-        if self.dtype == sqlalchemy.String:
-            # For short strings retain them as strings
-            if self.dtype == sqlalchemy.String and self.length and self.length <= 32:
-                return True
+        # For short strings retain them as strings
+        if self.dtype is sqlalchemy.String and self.length and self.length <= 32:
+            return True
         return False
 
     def getSizedColumnType(self) -> sqlalchemy.types.TypeEngine | type:
@@ -419,7 +416,7 @@ def getSizedColumnType(self) -> sqlalchemy.types.TypeEngine | type:
         """
         if self.length is not None:
             # Last chance check that we are only looking at possible String
-            if self.dtype == sqlalchemy.String and not self.isStringType():
+            if self.dtype is sqlalchemy.String and not self.isStringType():
                 return sqlalchemy.Text
             return self.dtype(length=self.length)
         if self.nbytes is not None:
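The `==` → `is` switches in this file compare against a specific class object. On classes, `==` usually falls back to identity anyway, but `is` states the intent explicitly and cannot be intercepted by `__eq__` overloading (which SQLAlchemy uses heavily on instances to build SQL expressions). A toy illustration:

class Chatty:
    def __eq__(self, other):
        return True  # instances claim equality with anything

dtype = str
assert dtype is str  # identity: exactly this class object

# With instances, == can be arbitrarily redefined, while `is` cannot:
assert Chatty() == "anything at all"
assert Chatty() is not str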

