Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhong-al committed Dec 14, 2024
1 parent 5c59afb commit d5c5dcf
Showing 1 changed file with 40 additions and 14 deletions.
54 changes: 40 additions & 14 deletions tests/test_miniscene2behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import requests
from unittest.mock import Mock, patch
import torch
from lxml import etree
import numpy as np
import pandas as pd
from kabr_tools import (
Expand Down Expand Up @@ -97,12 +98,13 @@ def test_run(self):
@patch('kabr_tools.miniscene2behavior.process_cv2_inputs')
@patch('kabr_tools.miniscene2behavior.cv2.VideoCapture')
def test_matching_tracks(self, video_capture, process_cv2_inputs):

# Create fake model that always returns a prediction of 1
# create fake model that weights class 98
mock_model = Mock()
mock_model.return_value = torch.tensor([1])
prob = torch.zeros(99)
prob[-1] = 1
mock_model.return_value = prob

# Create fake cfg
# create fake cfg
mock_config = Mock(
DATA=Mock(NUM_FRAMES=16,
SAMPLING_RATE=5,
Expand All @@ -111,25 +113,36 @@ def test_matching_tracks(self, video_capture, process_cv2_inputs):
OUTPUT_DIR=''
)

# Create fake video capture
# create fake video capture
vc = video_capture.return_value
vc.read.return_value = True, np.zeros((8, 8, 3), np.uint8)
vc.get.return_value = 1
vc.get.return_value = 21

self.output = '/tmp/annotation_data.csv'
miniscene_dir = os.path.join(EXAMPLESDIR, "MINISCENE1")
video_name = "DJI"

annotate_miniscene(cfg=mock_config,
model=mock_model,
miniscene_path=os.path.join(
EXAMPLESDIR, "MINISCENE1"),
video='DJI',
miniscene_path=miniscene_dir,
video=video_name,
output_path=self.output)

# Read in output CSV and make sure we have the expected columns and at least one row
# check output CSV
df = pd.read_csv(self.output, sep=' ')
self.assertEqual(list(df.columns), [
"video", "track", "frame", "label"])
self.assertGreater(len(df.index), 0)
row_ct = 0

root = etree.parse(
f"{miniscene_dir}/metadata/DJI_tracks.xml").getroot()
for track in root.iterfind("track"):
track_id = int(track.get("id"))
for box in track.iterfind("box"):
row_val = [video_name, track_id, int(box.get("frame")), 98]
self.assertEqual(list(df.loc[row_ct]), row_val)
row_ct += 1
self.assertEqual(len(df.index), row_ct)

@patch('kabr_tools.miniscene2behavior.process_cv2_inputs')
@patch('kabr_tools.miniscene2behavior.cv2.VideoCapture')
Expand All @@ -151,9 +164,11 @@ def test_nonmatching_tracks(self, video_capture, process_cv2_inputs):
# Create fake video capture
vc = video_capture.return_value
vc.read.return_value = True, np.zeros((8, 8, 3), np.uint8)
vc.get.return_value = 1
vc.get.return_value = 21

self.output = '/tmp/annotation_data.csv'
miniscene_dir = os.path.join(EXAMPLESDIR, "MINISCENE2")
video_name = "DJI"

annotate_miniscene(cfg=mock_config,
model=mock_model,
Expand All @@ -162,11 +177,22 @@ def test_nonmatching_tracks(self, video_capture, process_cv2_inputs):
video='DJI',
output_path=self.output)

# Read in output CSV and make sure we have the expected columns and at least one row
# check output CSV
df = pd.read_csv(self.output, sep=' ')
self.assertEqual(list(df.columns), [
"video", "track", "frame", "label"])
self.assertGreater(len(df.index), 0)
row_ct = 0

root = etree.parse(
f"{miniscene_dir}/metadata/DJI_tracks.xml").getroot()
for track in root.iterfind("track"):
track_id = int(track.get("id"))
for box in track.iterfind("box"):
row_val = [video_name, track_id, int(box.get("frame")), 0]
self.assertEqual(list(df.loc[row_ct]), row_val)
row_ct += 1
self.assertEqual(len(df.index), row_ct)


def test_parse_arg_min(self):
# parse arguments
Expand Down

0 comments on commit d5c5dcf

Please sign in to comment.