Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
minkyu-choi07 committed Feb 11, 2024
1 parent a6a1457 commit cef85a4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 3 deletions.
2 changes: 1 addition & 1 deletion run_scripts/run_synthetic_tlv_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
type=str,
default="",
)
# - - - - - - Image Loader Arguement - - - - - - #
# - - - - - - Image Loader Argument - - - - - - #
parser.add_argument(
"--coco_image_source",
type=str,
Expand Down
10 changes: 10 additions & 0 deletions tlv_dataset/data/tlv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@ def __post_init__(self):
data=self.frames_of_interest
)

def sanity_check(self):
"""Sanity check."""
_labels_of_frames = []
for label in self.labels_of_frames:
if isinstance(label, str):
_labels_of_frames.append([label])
else:
_labels_of_frames.append(label)
self.labels_of_frames = _labels_of_frames

def save_frames(
self, path="/opt/Neuro-Symbolic-Video-Frame-Search/artifacts"
) -> None:
Expand Down
9 changes: 7 additions & 2 deletions tlv_dataset/generator/synthetic_tlv_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def generate(
proposition=ltl_frame.proposition,
)
self.extract_properties(ltl_frame.ltl_formula)
# sanity check
ltl_frame.sanity_check()
if save_as == "dict":
ltl_frame.save_as_dict(
save_path=self._save_dir
Expand Down Expand Up @@ -192,6 +194,7 @@ def generate_until_time_delta(
proposition=ltl_frame.proposition,
)
self.extract_properties(ltl_frame.ltl_formula)
ltl_frame.sanity_check()
if save_as == "dict":
ltl_frame.save_as_dict(
save_path=self._save_dir
Expand Down Expand Up @@ -455,9 +458,11 @@ def ltl_function(
if x not in temp_frames_of_interest
]

# 2. G "prop1"

# TODO: Make a false case
# for i, label in enumerate(labels_of_frame):
# if isinstance(label, str):
# labels_of_frame[i] = [label]

return TLVDataset(
ground_truth=True,
ltl_formula=ltl_formula,
Expand Down

0 comments on commit cef85a4

Please sign in to comment.