Skip to content

Commit

Permalink
feat(pt): add datafile option for change-bias (#3945)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Added optional `--datafile` argument to specify a file for system data
processing.

- **Bug Fixes**
- Improved `help` messages for `--datafile` argument to clarify its
usage.

- **Tests**
- Enhanced test coverage for changing bias with a new method that
handles data from a system file.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
iProzd authored Jul 3, 2024
1 parent 1c3e099 commit 29db791
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
9 changes: 8 additions & 1 deletion deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def main_parser() -> argparse.ArgumentParser:
"--datafile",
default=None,
type=str,
help="The path to file of test list.",
help="The path to the datafile, each line of which is a path to one data system.",
)
parser_tst.add_argument(
"-S",
Expand Down Expand Up @@ -685,6 +685,13 @@ def main_parser() -> argparse.ArgumentParser:
type=str,
help="The system dir. Recursively detect systems in this directory",
)
parser_change_bias_source.add_argument(
"-f",
"--datafile",
default=None,
type=str,
help="The path to the datafile, each line of which is a path to one data system.",
)
parser_change_bias_source.add_argument(
"-b",
"--bias-value",
Expand Down
7 changes: 6 additions & 1 deletion deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,12 @@ def change_bias(FLAGS):
updated_model = model_to_change
else:
# calculate bias on given systems
data_systems = process_systems(expand_sys_str(FLAGS.system))
if FLAGS.datafile is not None:
with open(FLAGS.datafile) as datalist:
all_sys = datalist.read().splitlines()
else:
all_sys = expand_sys_str(FLAGS.system)
data_systems = process_systems(all_sys)
data_single = DpLoaderSet(
data_systems,
1,
Expand Down
34 changes: 32 additions & 2 deletions source/tests/pt/test_change_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import shutil
import tempfile
import unittest
from copy import (
deepcopy,
Expand Down Expand Up @@ -36,6 +37,9 @@
to_torch_tensor,
)

from .common import (
run_dp,
)
from .model.test_permutation import (
model_se_e2_a,
)
Expand Down Expand Up @@ -77,12 +81,15 @@ def setUp(self):
self.model_path_data_bias = Path(current_path) / (
model_name + "data_bias" + ".pt"
)
self.model_path_data_file_bias = Path(current_path) / (
model_name + "data_file_bias" + ".pt"
)
self.model_path_user_bias = Path(current_path) / (
model_name + "user_bias" + ".pt"
)

def test_change_bias_with_data(self):
os.system(
run_dp(
f"dp --pt change-bias {self.model_path!s} -s {self.data_file[0]} -o {self.model_path_data_bias!s}"
)
state_dict = torch.load(str(self.model_path_data_bias), map_location=DEVICE)
Expand All @@ -99,9 +106,32 @@ def test_change_bias_with_data(self):
expected_bias = expected_model.get_out_bias()
torch.testing.assert_close(updated_bias, expected_bias)

def test_change_bias_with_data_sys_file(self):
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
with open(tmp_file.name, "w") as f:
f.writelines([sys + "\n" for sys in self.data_file])
run_dp(
f"dp --pt change-bias {self.model_path!s} -f {tmp_file.name} -o {self.model_path_data_file_bias!s}"
)
state_dict = torch.load(
str(self.model_path_data_file_bias), map_location=DEVICE
)
model_params = state_dict["model"]["_extra_state"]["model_params"]
model_for_wrapper = get_model_for_wrapper(model_params)
wrapper = ModelWrapper(model_for_wrapper)
wrapper.load_state_dict(state_dict["model"])
updated_bias = wrapper.model["Default"].get_out_bias()
expected_model = model_change_out_bias(
self.trainer.wrapper.model["Default"],
self.sampled,
_bias_adjust_mode="change-by-statistic",
)
expected_bias = expected_model.get_out_bias()
torch.testing.assert_close(updated_bias, expected_bias)

def test_change_bias_with_user_defined(self):
user_bias = [0.1, 3.2, -0.5]
os.system(
run_dp(
f"dp --pt change-bias {self.model_path!s} -b {' '.join([str(_) for _ in user_bias])} -o {self.model_path_user_bias!s}"
)
state_dict = torch.load(str(self.model_path_user_bias), map_location=DEVICE)
Expand Down

0 comments on commit 29db791

Please sign in to comment.