class ReaderTree():
    """Dict-like registry mapping file extensions to readers.

    Extensions may be compound (e.g. ``'md.html'``).  They are stored in a
    tree of nested ``defaultdict`` nodes keyed by extension component,
    walking from the *last* component to the first, so that a file name
    can be matched against the longest registered extension.  Inside a
    node the empty-string key ``''`` holds the reader registered for the
    extension that ends at that node.

    Example: registering ``'html'`` and ``'md.html'`` produces
    ``{'html': {'': <html reader>, 'md': {'': <md.html reader>}}}``.
    """

    def __init__(self):
        self.tree_dd = ReaderTree._rec_dd()

    def __str__(self):
        return str(self.as_dict())

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self.as_dict())

    def __iter__(self):
        """Iterate over the registered extensions (dotted strings)."""
        for key, _ in ReaderTree._walk(self.tree_dd):
            yield key

    def __setitem__(self, key, value):
        """Register ``value`` as the reader for extension ``key``."""
        node = self.tree_dd
        for component in reversed(key.split('.')):
            # Indexing the defaultdict creates missing nodes on the way.
            node = node[component]
        node[''] = value

    def __getitem__(self, key):
        """Return the reader for ``key``; raise ``KeyError`` if absent."""
        node = self._find_node(key)
        if node is None:
            raise KeyError(key)
        return node['']

    def __delitem__(self, key):
        """Remove the reader for ``key``; raise ``KeyError`` if absent.

        Nodes left empty by the removal are pruned so that ``as_dict()``
        only ever reflects registered extensions.
        """
        trail = []  # (parent node, component) pairs from the root down
        node = self.tree_dd
        for component in reversed(key.split('.')):
            if not isinstance(node, defaultdict) or component not in node:
                raise KeyError(key)
            trail.append((node, component))
            node = node[component]
        if not isinstance(node, defaultdict) or '' not in node:
            raise KeyError(key)

        del node['']
        for parent, component in reversed(trail):
            if parent[component]:
                break  # still holds a reader or deeper extensions
            del parent[component]

    def __contains__(self, item):
        return self._find_node(item) is not None

    def __len__(self):
        return sum(1 for _ in self)

    def keys(self):
        """Return an iterator over the registered extensions."""
        return self.__iter__()

    def values(self):
        """Yield the registered readers, in the same order as ``keys()``."""
        for _, value in ReaderTree._walk(self.tree_dd):
            yield value

    def items(self):
        """Yield ``(extension, reader)`` pairs from a single traversal."""
        return ReaderTree._walk(self.tree_dd)

    def get(self, key, default=None):
        """Return the reader for ``key``, or ``default`` if absent."""
        node = self._find_node(key)
        return default if node is None else node['']

    def setdefault(self, key, value):
        """Return the reader for ``key``, registering ``value`` if absent."""
        node = self._find_node(key)
        if node is not None:
            return node['']
        self[key] = value
        return value

    def clear(self):
        self.tree_dd.clear()

    def pop(self, key, default=None):
        """Remove ``key`` and return its reader.

        If ``key`` is absent, return ``default`` when one other than
        ``None`` was supplied (falsy defaults such as ``''`` are
        honoured); otherwise raise ``KeyError``.
        """
        node = self._find_node(key)
        if node is not None:
            value = node['']
            del self[key]
            return value
        if default is not None:
            return default
        raise KeyError(key)

    def copy(self):
        """Return an independent ``ReaderTree`` with the same entries."""
        clone = ReaderTree()
        clone.update(self)
        return clone

    def update(self, d):
        """Register every ``(extension, reader)`` pair of ``d.items()``."""
        for key, value in d.items():
            self[key] = value

    def get_format(self, filename):
        """Return the longest registered extension matching ``filename``.

        The extension is returned without a leading dot; ``''`` is
        returned when no registered extension matches.
        """
        node = self.tree_dd
        stem = filename
        matched = []
        best = ''
        while '.' in stem:
            stem, ext = os.path.splitext(stem)
            component = ext[1:]
            # splitext yields no usable component for names such as
            # '.hidden' or 'trailing.'; stop rather than loop forever.
            if not component or component not in node:
                break
            node = node[component]
            matched.insert(0, component)
            if '' in node:
                # Only a chain that ends on a registered reader counts.
                best = '.'.join(matched)
        return best

    def has_reader(self, filename):
        """Tell whether a reader is registered for ``filename``."""
        return self.get_format(filename) in self

    def as_dict(self):
        """Return the internal tree as plain nested dicts."""
        return ReaderTree._rec_dd_to_dict(self.tree_dd)

    def _find_node(self, key):
        """Return the node holding the reader for ``key``, or ``None``.

        Uses membership tests only, so a failed lookup never inserts
        nodes into the underlying ``defaultdict`` tree.
        """
        if not isinstance(key, str):
            return None
        node = self.tree_dd
        for component in reversed(key.split('.')):
            if not isinstance(node, defaultdict) or component not in node:
                return None
            node = node[component]
        if isinstance(node, defaultdict) and '' in node:
            return node
        return None

    @staticmethod
    def _rec_dd():
        """Create a ``defaultdict`` whose missing values are more of them."""
        return defaultdict(ReaderTree._rec_dd)

    @staticmethod
    def _rec_dd_to_dict(dd):
        """Recursively convert nested ``defaultdict`` nodes to ``dict``."""
        return {
            key: (ReaderTree._rec_dd_to_dict(value)
                  if isinstance(value, defaultdict) else value)
            for key, value in dd.items()
        }

    @staticmethod
    def _walk(node, suffix=()):
        """Yield ``(extension, reader)`` pairs found below ``node``.

        ``suffix`` holds the components already descended through, in
        file-name order.
        """
        for component, child in node.items():
            if component == '':
                continue  # the reader slot is reported by the parent
            path = (component,) + suffix
            if '' in child:
                yield '.'.join(path), child['']
            yield from ReaderTree._walk(child, path)
class ReaderTreeTest(unittest.TestCase):
    """Tests for the extension-to-reader mapping ``readers.ReaderTree``."""

    def setUp(self):
        # Reader names stand in for reader classes; only the mapping from
        # (possibly compound) extension to value is under test.
        readers_and_exts = {
            'BaseReader': ['static'],
            'RstReader': ['rst'],
            'HtmlReader': ['htm', 'html'],
            'MDReader': ['md', 'mk', 'mkdown', 'mkd'],
            'MDeepReader': ['md.html'],
            'FooReader': ['foo.bar.baz.yaz']
        }

        self.reader_classes = readers.ReaderTree()

        for reader, exts in readers_and_exts.items():
            for ext in exts:
                self.reader_classes[ext] = reader

    def test_correct_mapping_generated(self):
        # Components are nested from the last one to the first; the ''
        # key holds the reader registered at that point.
        expected_mapping = {
            'static': {'': 'BaseReader'},
            'rst': {'': 'RstReader'},
            'htm': {'': 'HtmlReader'},
            'html': {
                '': 'HtmlReader',
                'md': {'': 'MDeepReader'}
            },
            'md': {'': 'MDReader'},
            'mk': {'': 'MDReader'},
            'mkdown': {'': 'MDReader'},
            'mkd': {'': 'MDReader'},
            'yaz': {
                'baz': {
                    'bar': {
                        'foo': {'': 'FooReader'}}}}}

        self.assertEqual(expected_mapping, self.reader_classes.as_dict())

    def test_containment(self):
        self.assertIn('md.html', self.reader_classes)
        self.assertIn('html', self.reader_classes)
        self.assertNotIn('txt', self.reader_classes)

    def test_deletion(self):
        self.assertIn('rst', self.reader_classes)
        del self.reader_classes['rst']
        self.assertNotIn('rst', self.reader_classes)

    def test_update(self):
        self.reader_classes.update({
            'new.ext': 'NewExtReader',
            'txt': 'TxtReader'
        })
        self.assertEqual(self.reader_classes['new.ext'], 'NewExtReader')
        self.assertEqual(self.reader_classes['txt'], 'TxtReader')

    def test_get_format(self):
        get_format = self.reader_classes.get_format
        html_ext = get_format('text.html')
        md_ext = get_format('another.md')
        compound_ext = get_format('dots.compound.md.html')
        no_ext = get_format('no_extension')
        bar_ext = get_format('file.bar')

        self.assertEqual(html_ext, 'html')
        self.assertEqual(md_ext, 'md')
        self.assertEqual(compound_ext, 'md.html')
        self.assertEqual(no_ext, '')
        # 'bar' is only an inner component of 'foo.bar.baz.yaz'.
        self.assertEqual(bar_ext, '')

    def test_has_reader(self):
        has_reader = self.reader_classes.has_reader
        self.assertTrue(has_reader('text.html'))
        self.assertFalse(has_reader('no_ext'))
        self.assertFalse(has_reader('bad_ext.bar'))