Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 21, 2024
1 parent 2a38025 commit e5c0ecf
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
10 changes: 6 additions & 4 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)
from deepmd.pt.utils.utils import (
ActivationFn,
)
Expand All @@ -41,9 +44,6 @@
from .repformer_layer import (
RepformerLayer,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)

if not hasattr(torch.ops.deepmd, "border_op"):

Expand Down Expand Up @@ -493,7 +493,9 @@ def forward(
g1_ext = ret[0].unsqueeze(0)
if has_spin:
g1_real_ext, g1_virtual_ext = torch.split(g1_ext, [ng1, ng1], dim=2)
g1_ext = concat_switch_virtual(g1_real_ext, g1_virtual_ext, real_nloc)
g1_ext = concat_switch_virtual(
g1_real_ext, g1_virtual_ext, real_nloc
)
g1, g2, h2 = ll.forward(
g1_ext,
g2,
Expand Down
6 changes: 3 additions & 3 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)
from deepmd.pt.utils.utils import (
to_torch_tensor,
)
Expand All @@ -24,9 +27,6 @@
from deepmd.utils.spin import (
Spin,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)

from .make_model import (
make_model,
Expand Down
16 changes: 8 additions & 8 deletions deepmd/pt/utils/spin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import Optional
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)

import torch


Expand All @@ -24,12 +28,8 @@ def concat_switch_virtual(
device=extended_tensor.device,
)
extended_tensor_updated[:, :nloc] = extended_tensor[:, :nloc]
extended_tensor_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[
:, :nloc
]
extended_tensor_updated[:, nloc + nloc : nloc + nall] = extended_tensor[
:, nloc:
]
extended_tensor_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[:, :nloc]
extended_tensor_updated[:, nloc + nloc : nloc + nall] = extended_tensor[:, nloc:]
extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:]
# nloc + nloc + nghost + nghost
if recv_num is not None:
Expand Down Expand Up @@ -59,4 +59,4 @@ def concat_switch_virtual(
] = extended_tensor_virtual[
:, nloc + origin_prefix_sum[i] : nloc + origin_prefix_sum[i + 1]
]
return extended_tensor_updated.view(out_shape)
return extended_tensor_updated.view(out_shape)

0 comments on commit e5c0ecf

Please sign in to comment.