From 693043bd443a288d41afb33963990b9ddb343797 Mon Sep 17 00:00:00 2001 From: Joseph Kuo Date: Fri, 8 Mar 2024 11:00:49 +0100 Subject: [PATCH] update GRegionsTree subtract with strandness --- genomkit/regions/gregions_intervaltree.py | 110 +++++++++++++++++++++- tests/test_gregionstree.py | 97 +++++++++---------- 2 files changed, 154 insertions(+), 53 deletions(-) diff --git a/genomkit/regions/gregions_intervaltree.py b/genomkit/regions/gregions_intervaltree.py index 0589bbc..e295dfc 100644 --- a/genomkit/regions/gregions_intervaltree.py +++ b/genomkit/regions/gregions_intervaltree.py @@ -448,7 +448,7 @@ def sampling(self, size: int, seed: int = None): res = GRegionsTree(name="sampling") sampling = random.sample(range(len(self)), size) for i in sampling: - res.add(self.elements[i]) + res.add(self[i]) return res def split(self, ratio: float, size: int = None, seed: int = None): @@ -534,7 +534,7 @@ def overlap_count(self, target): return len(intersect) def subtract(self, regions, whole_region: bool = False, - merge: bool = True, exact: bool = False, + strandness: bool = False, exact: bool = False, inplace: bool = True): """Subtract regions from the self regions. @@ -543,8 +543,8 @@ def subtract(self, regions, whole_region: bool = False, :param whole_region: Subtract the whole region, not partially, defaults to False :type whole_region: bool, default to False - :param merge: Merging the regions before subtracting - :type merge: bool, default to True + :param strandness: Define whether strandness is considered. + :type strandness: bool :param exact: Only regions which match exactly with a given region are subtracted. If True, whole_region and merge are completely ignored and the returned GRegions is sorted @@ -562,7 +562,107 @@ def subtract(self, regions, whole_region: bool = False, regions ---------- ---- Result ------- ------ """ - pass + assert isinstance(regions, GRegionsTree) + + def remain_interval(seq, begin, end, interval, res): + gr = GRegion(sequence=seq, start=begin, end=end, + name=interval.data.name, score=interval.data.score, + orientation=interval.data.orientation) + remain = Interval(begin, end, gr) + res.elements[seq].add(remain) + + res = GRegionsTree(name=self.name) + for seq in self.elements.keys(): + for interval in self.elements[seq]: + # Check if exact match is required + if exact: + # If interval not found in regions, add it to result + if interval not in regions.elements[seq]: + res.elements[seq].add(interval) + elif whole_region: + if any(regions.elements[seq].overlap( + interval.begin, interval.end)): + continue + elif strandness: + for r in regions.elements[seq].overlap( + interval.begin, interval.end): + # interval ----- + # r -------------- + # remain + if interval.begin > r.begin and \ + interval.end < r.end and \ + interval.data.orientation ==\ + r.data.orientation: + continue + # interval -------------- + # r ----- + # remain ---- ----- + elif interval.begin < r.begin and \ + interval.end > r.end and \ + interval.data.orientation ==\ + r.data.orientation: + remain_interval(seq, interval.begin, r.begin, + interval, res) + remain_interval(seq, r.end, interval.end, + interval, res) + # interval ------- + # r ------- + # remain ---- + elif interval.begin < r.begin and \ + interval.end <= r.end and \ + interval.data.orientation ==\ + r.data.orientation: + remain_interval(seq, interval.begin, r.begin, + interval, res) + # interval ------- + # r ------- + # remain --- + elif interval.begin >= r.begin and \ + interval.end > r.end and \ + interval.data.orientation ==\ + r.data.orientation: + remain_interval(seq, r.end, interval.end, + interval, res) + else: + res.elements[seq].add(interval) + else: + for r in regions.elements[seq].overlap( + interval.begin, interval.end): + # interval ----- + # r -------------- + # remain + if interval.begin > r.begin and \ + interval.end < r.end: + continue + # interval -------------- + # r ----- + # remain ---- ----- + elif interval.begin < r.begin and \ + interval.end > r.end: + remain_interval(seq, interval.begin, r.begin, + interval, res) + remain_interval(seq, r.end, interval.end, + interval, res) + # interval ------- + # r ------- + # remain ---- + elif interval.begin < r.begin and \ + interval.end <= r.end: + remain_interval(seq, interval.begin, r.begin, + interval, res) + # interval ------- + # r ------- + # remain --- + elif interval.begin >= r.begin and \ + interval.end > r.end: + remain_interval(seq, r.end, interval.end, + interval, res) + else: + res.elements[seq].add(interval) + if inplace: + self.elements = res.elements + else: + return res def get_GSequences(self, FASTA_file): """Return a GSequences object according to the loci on the given diff --git a/tests/test_gregionstree.py b/tests/test_gregionstree.py index cb8b2cc..7d0cfb2 100644 --- a/tests/test_gregionstree.py +++ b/tests/test_gregionstree.py @@ -100,16 +100,6 @@ def test_intersect(self): intersect = regions1.intersect(regions2, mode='COMP_INCL') self.assertEqual(len(intersect), 0) - # def test_intersect_array(self): - # regions1 = GRegionsTree(name="test") - # regions1.load(filename=os.path.join(script_path, - # "test_files/bed/example.bed")) - # regions2 = GRegionsTree(name="test") - # regions2.load(filename=os.path.join(script_path, - # "test_files/bed/example2.bed")) - # intersect = regions1.intersect_array(regions2) - # self.assertEqual(len(intersect), 4) - def test_merge(self): regions = GRegionsTree(name="test") regions.load(filename=os.path.join(script_path, @@ -134,44 +124,55 @@ def test_remove_duplicates(self): regions.remove_duplicates() self.assertEqual(len(regions), 4) - # def test_sort(self): - # regions = GRegionsTree(name="test") - # regions.load(filename=os.path.join(script_path, - # "test_files/bed/example4.bed")) - # regions.sort() - ## Duplicate interval should be kept - # self.assertEqual(len(regions), 6) - # self.assertEqual(regions.get_sequences(), - # ["chr1", "chr1", "chr1", "chr2", "chr2", "chr2"]) - - # def test_sampling(self): - # regions = load_BED(filename=os.path.join(script_path, - # "test_files/bed/example4.bed")) - # sampling = regions.sampling(size=3) - # self.assertEqual(len(sampling), 3) - - # def test_subtract(self): - # regions1 = GRegionsTree(name="test") - # regions1.load(filename=os.path.join(script_path, - # "test_files/bed/example.bed")) - # regions2 = GRegionsTree(name="test") - # regions2.load(filename=os.path.join(script_path, - # "test_files/bed/example2.bed")) - # regions1.subtract(regions2) - # self.assertEqual(len(regions1), 4) - # self.assertEqual(len(regions1[0]), 500) - # self.assertEqual(len(regions1[1]), 500) - # self.assertEqual(len(regions1[2]), 500) - # self.assertEqual(len(regions1[3]), 500) - - # # regions1 = GRegionsTree(name="test") - # # regions1.load(filename=os.path.join(script_path, - # # "test_files/bed/example.bed")) - # # regions2 = GRegionsTree(name="test") - # # regions2.load(filename=os.path.join(script_path, - # # "test_files/bed/example2.bed")) - # # regions1.subtract(regions2, whole_region=True) - # # self.assertEqual(len(regions1[0]), 0) + def test_sort(self): + regions = GRegionsTree(name="test") + regions.load(filename=os.path.join(script_path, + "test_files/bed/example4.bed")) + regions.sort() + self.assertEqual(len(regions), 4) + self.assertEqual(sorted(regions.get_sequences()), + ["chr1", "chr1", "chr2", "chr2"]) + + def test_sampling(self): + regions = GRegionsTree( + load=os.path.join(script_path, "test_files/bed/example4.bed")) + sampling = regions.sampling(size=3) + self.assertEqual(len(sampling), 3) + + def test_subtract(self): + regions1 = GRegionsTree(name="test") + regions1.load(filename=os.path.join(script_path, + "test_files/bed/example.bed")) + regions2 = GRegionsTree(name="test") + regions2.load(filename=os.path.join(script_path, + "test_files/bed/example2.bed")) + regions1.subtract(regions2) + for r in regions1: + print(r) + self.assertEqual(len(regions1), 4) + self.assertEqual(len(regions1[0]), 500) + self.assertEqual(len(regions1[1]), 500) + self.assertEqual(len(regions1[2]), 500) + self.assertEqual(len(regions1[3]), 500) + regions1 = GRegionsTree(name="test") + regions1.load(filename=os.path.join(script_path, + "test_files/bed/example.bed")) + regions1.subtract(regions2, strandness=True) + for r in regions1: + print(r) + self.assertEqual(len(regions1), 4) + self.assertEqual(len(regions1[0]), 500) + self.assertEqual(len(regions1[1]), 1000) + self.assertEqual(len(regions1[2]), 1000) + self.assertEqual(len(regions1[3]), 500) + # regions1 = GRegionsTree(name="test") + # regions1.load(filename=os.path.join(script_path, + # "test_files/bed/example.bed")) + # regions2 = GRegionsTree(name="test") + # regions2.load(filename=os.path.join(script_path, + # "test_files/bed/example2.bed")) + # regions1.subtract(regions2, whole_region=True) + # self.assertEqual(len(regions1[0]), 0) # def test_total_coverage(self): # regions1 = GRegionsTree(name="test")