Skip to content

Commit

Permalink
Improved FileIO
Browse files Browse the repository at this point in the history
Now the file is kept open to reduce the cost of opening and closing the MMap file.
  • Loading branch information
GMW99 committed Jul 4, 2024
1 parent 25c5fcd commit 6aab9ca
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions examples/mvp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@
"id": "eac8d73df27f4355",
"metadata": {},
"source": [
"def decode_mrc_data_path(data_paths: DataPathFrame) -> np.array:\n",
"def decode_mrc_data_path(file,data_paths) -> np.array:\n",
" \"\"\"\n",
" Decode the MRC data path returning the slice of the data speficifed.\n",
"\n",
Expand All @@ -591,10 +591,10 @@
" A Numpy array containing the decoded data along with an additional axis to create Height, Width, Channel.\n",
"\n",
" Examples:\n",
" If we have a Data path array of (\"/tmp/0.mrc\",0) with shape (5,5) this will\n",
" If we have an Data path array of (\"/tmp/0.mrc\",0) with shape (5,5) this will\n",
" return the numpy array with the shape (5,5,1)\n",
" \"\"\"\n",
" mrc_file = mrcfile.mmap(data_paths[0])\n",
" mrc_file = file\n",
" frame_index = np.argmax(mrc_file.data.shape)\n",
" if len(mrc_file.data.shape) == 3:\n",
" if frame_index == 0:\n",
Expand All @@ -616,7 +616,7 @@
" A list of paths and frames.\n",
"\n",
" Example:\n",
" If we have an MRCfile shape 1,2,3 at /tmp/0.mrc then::\n",
" If we have an MRC file shape 1,2,3 at /tmp/0.mrc then::\n",
"\n",
" data = get_data_paths_and_frames(\"/tmp/0.mrc\")\n",
"\n",
Expand Down Expand Up @@ -654,6 +654,8 @@
" transform: typing.Optional[torchvision.transforms.Compose] = None,\n",
" ):\n",
" self.dataset = dataset\n",
" \n",
" self.data_file = mrcfile.mmap(dataset[0][0], mode=\"r\")\n",
" self.transform = transform\n",
"\n",
" def __len__(self) -> int:\n",
Expand All @@ -674,7 +676,7 @@
" \"\"\"\n",
" if torch.is_tensor(idx):\n",
" idx = idx.tolist()\n",
" frame = decode_mrc_data_path(self.dataset[idx])\n",
" frame = decode_mrc_data_path(self.data_file, self.dataset[idx])\n",
" if self.transform:\n",
" frame = self.transform(frame)\n",
" return frame"
Expand Down Expand Up @@ -904,11 +906,11 @@
"metadata": {},
"source": [
"config[\"learning_rate\"] = 1e-5\n",
"config[\"epochs\"] = 10\n",
"config[\"epochs\"] = 5\n",
"config[\"dataset_name\"] = \"name\"\n",
"config[\"seed\"] = 0\n",
"config[\"padding\"] = 1\n",
"config[\"batch_size\"] = 32\n",
"config[\"batch_size\"] = 64\n",
"config[\"levels\"] = 1\n",
"config[\"test_split\"] = 0.2\n",
"config[\"model_name\"] = \"ConvSRM\"\n",
Expand Down

0 comments on commit 6aab9ca

Please sign in to comment.