Skip to content

Commit

Permalink
Merge pull request #89 from vikstrous/mirrors
Browse files Browse the repository at this point in the history
Allow users to set custom mirrors
  • Loading branch information
rnhmjoj authored Sep 3, 2016
2 parents a1cba67 + bba0f41 commit 68e61af
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 25 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ transmission = false

; use colored output
colors = true

; the pirate bay mirror(s) to use:
; one or more space separated URLs
mirror = http://thepiratebay.org
```

Note:
Expand Down
47 changes: 25 additions & 22 deletions pirate/pirate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def parse_config_file(text):
config.set('Misc', 'openCommand', '')
config.set('Misc', 'transmission', 'false')
config.set('Misc', 'colors', 'true')
config.set('Misc', 'mirror', pirate.data.default_mirror)

config.read_string(text)

Expand All @@ -54,16 +55,12 @@ def parse_config_file(text):

def load_config():
# user-defined config files
main = expandvars('$XDG_CONFIG_HOME/pirate-get')
alt = expanduser('~/.config/pirate-get')
config_home = os.getenv('XDG_CONFIG_HOME', '~/.config')
config = expanduser(os.path.join(config_home, 'pirate-get'))

# read config file
if os.path.isfile(main):
with open(main) as f:
return parse_config_file(f.read())

if os.path.isfile(alt):
with open(alt) as f:
if os.path.isfile(config):
with open(config) as f:
return parse_config_file(f.read())

return parse_config_file("")
Expand Down Expand Up @@ -173,6 +170,9 @@ def parse_args(args_in):
parser.add_argument('--disable-colors', dest='color',
action='store_false',
help='disable colored output')
parser.add_argument('-m', '--mirror',
type=str, nargs='+',
help='the pirate bay mirror(s) to use')
args = parser.parse_args(args_in)

return args
Expand Down Expand Up @@ -207,6 +207,9 @@ def combine_configs(config, args):
if not args.save_directory:
args.save_directory = config.get('Save', 'directory')

if not args.mirror:
args.mirror = config.get('Misc', 'mirror').split()

args.transmission_command = ['transmission-remote']
if args.port:
args.transmission_command.append(args.port)
Expand All @@ -228,16 +231,16 @@ def combine_configs(config, args):
return args


def connect_mirror(mirror, printer, pages, category, sort, action, search):
def connect_mirror(mirror, printer, args):
try:
printer.print('Trying', mirror, end='... ')
results = pirate.torrent.remote(
printer=printer,
pages=pages,
category=pirate.torrent.parse_category(printer, category),
sort=pirate.torrent.parse_sort(printer, sort),
mode=action,
terms=search,
pages=args.pages,
category=pirate.torrent.parse_category(printer, args.category),
sort=pirate.torrent.parse_sort(printer, args.sort),
mode=args.action,
terms=args.search,
mirror=mirror)
except (urllib.error.URLError, socket.timeout, IOError, ValueError):
printer.print('Failed', color='WARN')
Expand All @@ -247,11 +250,12 @@ def connect_mirror(mirror, printer, pages, category, sort, action, search):
return results, mirror


def search_mirrors(printer, *args):
# try official site
result = connect_mirror(pirate.data.default_mirror, printer, *args)
if result:
return result
def search_mirrors(printer, args):
# try default or user mirrors
for mirror in args.mirror:
result = connect_mirror(mirror, printer, args)
if result:
return result

# download mirror list
try:
Expand All @@ -271,7 +275,7 @@ def search_mirrors(printer, *args):
for mirror in mirrors:
if mirror in pirate.data.blacklist:
continue
result = connect_mirror(mirror, printer, *args)
result = connect_mirror(mirror, printer, args)
if result:
return result
else:
Expand Down Expand Up @@ -312,8 +316,7 @@ def pirate_main(args):
results = pirate.local.search(args.database, args.search)
elif args.source == 'tpb':
try:
results, site = search_mirrors(printer, args.pages, args.category,
args.sort, args.action, args.search)
results, site = search_mirrors(printer, args)
except IOError as e:
printer.print(e.args[0] + ' :( ', color='ERROR')
if len(e.args) > 1:
Expand Down
9 changes: 6 additions & 3 deletions tests/test_pirate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import socket
import unittest
import subprocess
from argparse import Namespace
from unittest import mock
from unittest.mock import patch, call, MagicMock

Expand Down Expand Up @@ -147,7 +148,9 @@ def test_parse_args(self):
self.assertEqual(test[2][option], value)

def test_search_mirrors(self):
pages, category, sort, action, search = (1, 100, 10, 'browse', [])
args = Namespace(pages=1, category=100, sort=10,
action='browse', search=[],
mirror=[pirate.data.default_mirror])
class MockResponse():
readlines = mock.MagicMock(return_value=[x.encode('utf-8') for x in ['', '', '', 'https://example.com']])
info = mock.MagicMock()
Expand All @@ -156,12 +159,12 @@ class MockResponse():
printer = MagicMock(Printer)
with patch('urllib.request.urlopen', return_value=response_obj) as urlopen:
with patch('pirate.torrent.remote', return_value=[]) as remote:
results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search)
results, mirror = pirate.pirate.search_mirrors(printer, args)
self.assertEqual(results, [])
self.assertEqual(mirror, pirate.data.default_mirror)
remote.assert_called_once_with(printer=printer, pages=1, category=100, sort=10, mode='browse', terms=[], mirror=pirate.data.default_mirror)
with patch('pirate.torrent.remote', side_effect=[socket.timeout, []]) as remote:
results, mirror = pirate.pirate.search_mirrors(printer, pages, category, sort, action, search)
results, mirror = pirate.pirate.search_mirrors(printer, args)
self.assertEqual(results, [])
self.assertEqual(mirror, 'https://example.com')
remote.assert_has_calls([
Expand Down

0 comments on commit 68e61af

Please sign in to comment.