Skip to content

Commit

Permalink
update GRegionsTree subtract with strandness
Browse files Browse the repository at this point in the history
  • Loading branch information
Joseph Kuo committed Mar 8, 2024
1 parent 2ecb284 commit 693043b
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 53 deletions.
110 changes: 105 additions & 5 deletions genomkit/regions/gregions_intervaltree.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def sampling(self, size: int, seed: int = None):
res = GRegionsTree(name="sampling")
sampling = random.sample(range(len(self)), size)
for i in sampling:
res.add(self.elements[i])
res.add(self[i])
return res

def split(self, ratio: float, size: int = None, seed: int = None):
Expand Down Expand Up @@ -534,7 +534,7 @@ def overlap_count(self, target):
return len(intersect)

def subtract(self, regions, whole_region: bool = False,
             strandness: bool = False, exact: bool = False,
             inplace: bool = True):
    """Subtract regions from the self regions.

    :param regions: Regions to be subtracted from self.
    :type regions: GRegionsTree
    :param whole_region: Subtract the whole region, not partially,
                         defaults to False
    :type whole_region: bool, default to False
    :param strandness: Define whether strandness is considered. If True,
                       only regions with the same orientation as a self
                       region can subtract from it (this also applies
                       when whole_region is True).
    :type strandness: bool
    :param exact: Only regions which match exactly with a given region are
                  subtracted. If True, whole_region and strandness are
                  completely ignored.
    :type exact: bool
    :param inplace: If True, replace the elements of self with the result
                    and return None; otherwise return a new GRegionsTree.
    :type inplace: bool
    :return: None if inplace is True, otherwise the remaining regions.
    :rtype: GRegionsTree or None
    :raises TypeError: If regions is not a GRegionsTree.

    ::

        self      ----------              ------
        regions         ----------                    ----
        Result    ------                  ------
    """
    # Explicit check instead of ``assert`` so validation survives ``-O``.
    if not isinstance(regions, GRegionsTree):
        raise TypeError("regions must be a GRegionsTree")

    def add_remainder(seq, begin, end, source, out):
        # Store the uncut piece [begin, end) with the metadata of source.
        gr = GRegion(sequence=seq, start=begin, end=end,
                     name=source.data.name, score=source.data.score,
                     orientation=source.data.orientation)
        out.elements[seq].add(Interval(begin, end, gr))

    res = GRegionsTree(name=self.name)
    for seq in self.elements.keys():
        # Membership test first: avoids a KeyError on a plain mapping and
        # avoids creating an empty entry in ``regions`` on a default one.
        if seq in regions.elements:
            targets = regions.elements[seq]
        else:
            targets = None
        for interval in self.elements[seq]:
            if targets is None:
                # Nothing to subtract on this sequence.
                res.elements[seq].add(interval)
                continue
            if exact:
                if interval not in targets:
                    res.elements[seq].add(interval)
                continue
            cutters = targets.overlap(interval.begin, interval.end)
            if strandness:
                cutters = [r for r in cutters
                           if r.data.orientation ==
                           interval.data.orientation]
            if not cutters:
                # Untouched regions are kept unchanged in every mode.
                res.elements[seq].add(interval)
                continue
            if whole_region:
                # Any overlap removes the region entirely.
                continue
            # Sweep the cutters from left to right so that several
            # overlapping regions are subtracted cumulatively instead of
            # each one being cut from the original interval independently.
            cursor = interval.begin
            for r in sorted(cutters, key=lambda c: (c.begin, c.end)):
                if r.begin > cursor:
                    add_remainder(seq, cursor, r.begin, interval, res)
                cursor = max(cursor, r.end)
                if cursor >= interval.end:
                    break
            if cursor < interval.end:
                add_remainder(seq, cursor, interval.end, interval, res)
    if inplace:
        self.elements = res.elements
    else:
        return res

def get_GSequences(self, FASTA_file):
"""Return a GSequences object according to the loci on the given
Expand Down
97 changes: 49 additions & 48 deletions tests/test_gregionstree.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,6 @@ def test_intersect(self):
intersect = regions1.intersect(regions2, mode='COMP_INCL')
self.assertEqual(len(intersect), 0)

# def test_intersect_array(self):
# regions1 = GRegionsTree(name="test")
# regions1.load(filename=os.path.join(script_path,
# "test_files/bed/example.bed"))
# regions2 = GRegionsTree(name="test")
# regions2.load(filename=os.path.join(script_path,
# "test_files/bed/example2.bed"))
# intersect = regions1.intersect_array(regions2)
# self.assertEqual(len(intersect), 4)

def test_merge(self):
regions = GRegionsTree(name="test")
regions.load(filename=os.path.join(script_path,
Expand All @@ -134,44 +124,55 @@ def test_remove_duplicates(self):
regions.remove_duplicates()
self.assertEqual(len(regions), 4)

# def test_sort(self):
# regions = GRegionsTree(name="test")
# regions.load(filename=os.path.join(script_path,
# "test_files/bed/example4.bed"))
# regions.sort()
## Duplicate interval should be kept
# self.assertEqual(len(regions), 6)
# self.assertEqual(regions.get_sequences(),
# ["chr1", "chr1", "chr1", "chr2", "chr2", "chr2"])

# def test_sampling(self):
# regions = load_BED(filename=os.path.join(script_path,
# "test_files/bed/example4.bed"))
# sampling = regions.sampling(size=3)
# self.assertEqual(len(sampling), 3)

# def test_subtract(self):
# regions1 = GRegionsTree(name="test")
# regions1.load(filename=os.path.join(script_path,
# "test_files/bed/example.bed"))
# regions2 = GRegionsTree(name="test")
# regions2.load(filename=os.path.join(script_path,
# "test_files/bed/example2.bed"))
# regions1.subtract(regions2)
# self.assertEqual(len(regions1), 4)
# self.assertEqual(len(regions1[0]), 500)
# self.assertEqual(len(regions1[1]), 500)
# self.assertEqual(len(regions1[2]), 500)
# self.assertEqual(len(regions1[3]), 500)

# # regions1 = GRegionsTree(name="test")
# # regions1.load(filename=os.path.join(script_path,
# # "test_files/bed/example.bed"))
# # regions2 = GRegionsTree(name="test")
# # regions2.load(filename=os.path.join(script_path,
# # "test_files/bed/example2.bed"))
# # regions1.subtract(regions2, whole_region=True)
# # self.assertEqual(len(regions1[0]), 0)
def test_sort(self):
    """Sorted example4.bed holds four regions, two per chromosome."""
    bed_path = os.path.join(script_path, "test_files/bed/example4.bed")
    tree = GRegionsTree(name="test")
    tree.load(filename=bed_path)
    tree.sort()
    self.assertEqual(len(tree), 4)
    sequence_names = sorted(tree.get_sequences())
    self.assertEqual(sequence_names, ["chr1", "chr1", "chr2", "chr2"])

def test_sampling(self):
    """Sampling three regions yields a collection of length three."""
    bed_path = os.path.join(script_path, "test_files/bed/example4.bed")
    source = GRegionsTree(load=bed_path)
    picked = source.sampling(size=3)
    self.assertEqual(len(picked), 3)

def test_subtract(self):
    """Check subtract() with and without strand awareness."""

    def load_bed(filename):
        # Build a fresh tree per call because subtract() works in place.
        tree = GRegionsTree(name="test")
        tree.load(filename=os.path.join(script_path,
                                        "test_files/bed/" + filename))
        return tree

    regions2 = load_bed("example2.bed")

    # Strand-agnostic subtraction.
    regions1 = load_bed("example.bed")
    regions1.subtract(regions2)
    self.assertEqual(len(regions1), 4)
    self.assertEqual([len(regions1[i]) for i in range(4)],
                     [500, 500, 500, 500])

    # Strand-aware subtraction.
    regions1 = load_bed("example.bed")
    regions1.subtract(regions2, strandness=True)
    self.assertEqual(len(regions1), 4)
    self.assertEqual([len(regions1[i]) for i in range(4)],
                     [500, 1000, 1000, 500])

    # TODO: add coverage for subtract(regions2, whole_region=True).

# def test_total_coverage(self):
# regions1 = GRegionsTree(name="test")
Expand Down

0 comments on commit 693043b

Please sign in to comment.