Skip to content
This repository has been archived by the owner on Oct 31, 2024. It is now read-only.

Commit

Permalink
Fixture (and Session, Flash) fix: add _safe_local-property which is…
Browse files Browse the repository at this point in the history
… thread safe and reset on every request
  • Loading branch information
valq7711 committed Sep 1, 2021
1 parent 826fb1a commit 36aab84
Show file tree
Hide file tree
Showing 7 changed files with 227 additions and 109 deletions.
10 changes: 8 additions & 2 deletions .github/workflows/run_test.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
name: ombott-test
name: master-test

on: [push, pull_request]
on:
push:
branches:
- master
pull_request:
branches:
- master

jobs:
build:
Expand Down
1 change: 0 additions & 1 deletion py4web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def _maybe_gevent():
Translator, # from pluralize
Session,
Cache,
Current,
Flash,
user_in, # additional fixtures
URL, # custom helper
Expand Down
169 changes: 83 additions & 86 deletions py4web/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,20 @@
os.environ["PY4WEB_PATH"] = str(pathlib.Path(__file__).resolve().parents[1])


# hold all framework hooks in one place
# NOTE: `after_request` hooks are not currently used
_REQUEST_HOOKS = types.SimpleNamespace(before=set())


def _before_request(*args, **kw):
[h(*args, **kw) for h in _REQUEST_HOOKS.before]


bottle.default_app().add_hook("before_request", _before_request)




def module2filename(module):
filename = os.path.join(*module.split(".")[1:])
filename = (
Expand Down Expand Up @@ -298,6 +312,29 @@ def dumps(obj, sort_keys=True, indent=2):


class Fixture:
__request_master_ctx__ = threading.local()

@classmethod
def __init_request_ctx__(cls):
cls.__request_master_ctx__.request_ctx = dict()

@classmethod
def __mount_local__(cls, self, storage):
cls.__request_master_ctx__.request_ctx[self] = storage

@property
def _safe_local(self):
try:
ret = self.__request_master_ctx__.request_ctx[self]
except KeyError as err:
msg = 'py4web hint: check @action.uses() for the missing fixture {}'.format(self)
raise RuntimeError(msg) from err
return ret

@_safe_local.setter
def _safe_local(self, storage):
self.__mount_local__(self, storage)

def on_request(self):
pass # called when a request arrives

Expand All @@ -316,6 +353,9 @@ def transform(
return output


_REQUEST_HOOKS.before.add(Fixture.__init_request_ctx__)


class Translator(pluralize.Translator, Fixture):
def on_request(self):
self.select(request.headers.get("Accept-Language", "en"))
Expand Down Expand Up @@ -381,57 +421,6 @@ def field_copy(self):
ICECUBE = {}


#########################################################################################
# Current Fixture
#########################################################################################


class NotInCurrent(Exception):
"""This exception is raised when one tries to access a request-local
object but one is not in a request-local context."""

pass


class Current(Fixture):
"""
This fixture gives access to a request-local object, that is cleaned
after each request. Note that the object is thread-local; if the
request processing uses multiple threads, this will not be accessible.
"""

def __init__(self):
self._local = threading.local()
self.local = None

def on_request(self):
self._local = self._local
self.local.data = {}

def finalize(self):
self.local = None

def __setitem__(self, key, value):
if self.local is None:
raise NotInCurrent()
self.local.data[key] = value

def __getitem__(self, key):
if self.local is None:
raise NotInCurrent()
return self.local.data[key]

def __delitem__(self, key):
if self.local is None:
raise NotInCurrent()
del self.local.data[key]

def get(self, key, default=None):
if self.local is None:
raise NotInCurrent()
return self.local.data.get(key, default)


#########################################################################################
# Flash Fixture
#########################################################################################
Expand All @@ -454,9 +443,13 @@ def index():
# this essential makes flash a singleton
# necessary because auth defines its own flash
# possible because flash does not depend on the app
local = threading.local()

@property
def local(self):
return self._safe_local

def on_request(self):
self._safe_local = types.SimpleNamespace()
# when a new request arrives we look for a flash message in the cookie
flash = request.get_cookie("py4web-flash")
if flash:
Expand Down Expand Up @@ -587,6 +580,10 @@ class Session(Fixture):
# the actual value is loaded from a file
SECRET = None

@property
def local(self):
return self._safe_local

def __init__(
self,
secret=None,
Expand All @@ -611,45 +608,47 @@ def __init__(
self.__prerequisites__ = [storage]
if hasattr(storage, "__prerequisites__"):
self.__prerequisites__ = storage.__prerequisites__
self._local = threading.local()
self.local = None # We initialize this per-request.

def initialize(self, app_name="unknown", data=None, changed=False, secure=False):
self.local = self._local
self.local.changed = changed
self.local.data = data or {}
self.local.session_cookie_name = "%s_session" % app_name
self.local.secure = secure
self._safe_local = types.SimpleNamespace()
local = self.local
local.changed = changed
local.data = data or {}
local.session_cookie_name = "%s_session" % app_name
local.secure = secure

def load(self):
self.initialize(
app_name=request.app_name,
changed=False,
secure=request.url.startswith("https"),
)
raw_token = request.get_cookie(
self.local.session_cookie_name
) or request.query.get("_session_token")
if not raw_token and request.method in ("POST", "PUT", "DELETE"):
raw_token = (request.forms and request.forms.get("_session_token")) or (
request.json and request.json.get("_session_token")
self_local = self.local
raw_token = (
request.get_cookie(self_local.session_cookie_name)
or request.query.get("_session_token")
)
if not raw_token and request.method in {"POST", "PUT", "DELETE", "PATCH"}:
raw_token = (
request.forms and request.forms.get("_session_token")
or request.json and request.json.get("_session_token")
)
if raw_token:
token_data = raw_token.encode()
try:
if self.storage:
json_data = self.storage.get(token_data)
if json_data:
self.local.data = json.loads(json_data)
self_local.data = json.loads(json_data)
else:
self.local.data = jwt.decode(
self_local.data = jwt.decode(
token_data, self.secret, algorithms=[self.algorithm]
)
if self.expiration is not None and self.storage is None:
assert self.local.data["timestamp"] > time.time() - int(
self.expiration
assert (
self_local.data["timestamp"] > time.time() - int(self.expiration)
)
assert self.get_data().get("secure") == self.local.secure
assert self.get_data().get("secure") == self_local.secure
except Exception:
pass
if not "uuid" in self.get_data():
Expand All @@ -659,22 +658,23 @@ def get_data(self):
return getattr(self.local, "data", {})

def save(self):
self.local.data["timestamp"] = time.time()
self_local = self.local
self_local.data["timestamp"] = time.time()
if self.storage:
cookie_data = self.local.data["uuid"]
self.storage.set(cookie_data, json.dumps(self.local.data), self.expiration)
cookie_data = self_local.data["uuid"]
self.storage.set(cookie_data, json.dumps(self_local.data), self.expiration)
else:
cookie_data = jwt.encode(
self.local.data, self.secret, algorithm=self.algorithm
self_local.data, self.secret, algorithm=self.algorithm
)
if isinstance(cookie_data, bytes):
cookie_data = cookie_data.decode()

response.set_cookie(
self.local.session_cookie_name,
self_local.session_cookie_name,
cookie_data,
path="/",
secure=self.local.secure,
secure=self_local.secure,
same_site=self.same_site,
)

Expand All @@ -697,15 +697,15 @@ def keys(self):
return self.get_data().keys()

def __iter__(self):
for item in self.get_data().items():
yield item
yield from self.get_data().items()

def clear(self):
"""Produces a brand-new session."""
self.local.changed = True
self.local.data.clear()
self.local.data["uuid"] = str(uuid.uuid1())
self.local.data["secure"] = self.local.secure
self_local = self.local
self_local.changed = True
self_local.data.clear()
self_local.data["uuid"] = str(uuid.uuid1())
self_local.data["secure"] = self_local.secure

def on_request(self):
self.load()
Expand All @@ -718,9 +718,6 @@ def on_success(self, status):
if self.local.changed:
self.save()

def finalize(self):
self.local = None # To prevent leakage


#########################################################################################
# The URL Helper
Expand Down Expand Up @@ -1200,7 +1197,7 @@ def hook(*a, **k):
## APP_WATCH tasks, if used by any app
try_app_watch_tasks()

bottle.default_app().add_hook("before_request", hook)
_REQUEST_HOOKS.before.add(hook)

@staticmethod
def clear_routes(app_name=None):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
def index():
db.thing.insert(name="test")
session["number"] = session.get("number", 0) + 1

# test copying Field ThreadSafe attr
db.thing.name.default = "test_clone"
field_clone = copy.copy(db.thing.name)
Expand Down Expand Up @@ -65,6 +65,7 @@ def test_action(self):

def test_local(self):
# for test coverage
Session.__init_request_ctx__() # mimic before_request-hook
index()

def test_error_page(self):
Expand Down
Loading

0 comments on commit 36aab84

Please sign in to comment.