Skip to content

Commit

Permalink
Enable CSRF protection globally by default
Browse files Browse the repository at this point in the history
Enable Pyramid's `config.set_default_csrf_options(require_csrf=True)`
which causes it to require a valid CSRF token for all requests with a
request method that is *not* one of `GET`, `HEAD`, `OPTIONS` or `TRACE`.

The CSRF token must be in a csrf_token POST parameter or an
X-CSRF-Token header, and must match the CSRF token stored in the signed
session cookie.

It also checks that the request's `Referer` (if any) is the current
host.

See:

* https://docs.pylonsproject.org/projects/pyramid/en/latest/narr/security.html#checking-csrf-tokens-automatically
* https://docs.pylonsproject.org/projects/pyramid/en/latest/api/config.html#pyramid.config.Configurator.set_default_csrf_options

This is a safer default. The current implementation requires all views
receiving form submissions to use a Colander schema that's a subclass of
`CSRFSchema`. It's too easy to forget to add CSRF protection to a form
if it doesn't use Colander (for example: perhaps there are no parameters
to be validated) or if it has a schema that doesn't subclass
`CSRFSchema`. Even if the view's schema *does* sublass `CSRFSchema`, if
it wants to have a `validate()` method it must remember to call
`super().validate()` or it'll disable `CSRFSchema`'s CSRF protection.

This commit removes the CSRF protection code form `CSRFSchema` (that
schema is now only used to *serialize* the CSRF tokens into the forms,
but doesn't do any CSRF validation at *deserialization* time) and
instead enables Pyramid's global CSRF protection option.

CSRF protection can be disabled for individual views by passing
`require_csrf=False` to `@view_config`. This has been added to h's
custom `@api_config` decorator so that CSRF protection is disabled for
all API endpoints.
  • Loading branch information
seanh committed Feb 7, 2025
1 parent 3afebfd commit 78d79ed
Show file tree
Hide file tree
Showing 14 changed files with 41 additions and 84 deletions.
4 changes: 0 additions & 4 deletions h/accounts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ class EmailChangeSchema(CSRFSchema):
password = password_node(title=_("Confirm password"), hide_until_form_active=True)

def validator(self, node, value):
super().validator(node, value)
exc = colander.Invalid(node)
request = node.bindings["request"]
svc = request.find_service(name="user_password")
Expand Down Expand Up @@ -229,7 +228,6 @@ class PasswordChangeSchema(CSRFSchema):
)

def validator(self, node, value): # pragma: no cover
super().validator(node, value)
exc = colander.Invalid(node)
request = node.bindings["request"]
svc = request.find_service(name="user_password")
Expand All @@ -249,8 +247,6 @@ class DeleteAccountSchema(CSRFSchema):
password = password_node(title=_("Confirm password"))

def validator(self, node, value):
super().validator(node, value)

request = node.bindings["request"]
svc = request.find_service(name="user_password")

Expand Down
2 changes: 2 additions & 0 deletions h/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def create_app(_global_config, **settings): # pragma: no cover


def includeme(config): # pragma: no cover
config.set_default_csrf_options(require_csrf=True)

config.scan("h.subscribers")

config.add_tween("h.tweens.conditional_http_tween_factory", under=EXCVIEW)
Expand Down
2 changes: 0 additions & 2 deletions h/schemas/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ class CreateAuthClientSchema(CSRFSchema):
)

def validator(self, node, value):
super().validator(node, value)

grant_type = value.get("grant_type")
redirect_url = value.get("redirect_url")

Expand Down
44 changes: 35 additions & 9 deletions h/schemas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import deform
import jsonschema
from pyramid import httpexceptions
from pyramid.csrf import check_csrf_token, get_csrf_token
from pyramid.csrf import get_csrf_token


@colander.deferred
Expand All @@ -21,23 +21,49 @@ class ValidationError(httpexceptions.HTTPBadRequest):

class CSRFSchema(colander.Schema):
"""
A CSRFSchema backward-compatible with the one from the hem module.
Add a hidden CSRF token to forms when seralized using Deform.
Unlike hem, this doesn't require that the csrf_token appear in the
serialized appstruct.
This is intended as a base class for other schemas to inherit from if the
schema's form needs a CSRF token (by default all form submissions do need a
CSRF token).
This schema *does not* implement CSRF verification when receiving requests.
That's enabled globally for non-GET requests by
config.set_default_csrf_options(require_csrf=True).
"""

csrf_token = colander.SchemaNode(
colander.String(),
widget=deform.widget.HiddenWidget(),
# When serializing (i.e. when rendering a form) if there's no
# csrf_token then call deferred_csrf_token() to get one.
default=deferred_csrf_token,
missing=None,
# Allow data with no "csrf_token" field to be *deserialized* successfully
# (the deserialized data will contain no "csrf_token" field.)
#
# CSRF protection isn't provided by this schema, it's provided by
# Pyramid's config.set_default_csrf_options(require_csrf=True).
#
# Nonetheless, without a `missing` value, when deserializing any
# subclass of this schema Colander would require a csrf_token field to
# be present in the data (even if this schema doesn't actually check
# that the token is valid).
#
# In production any request missing a CSRF token would be rejected by
# Pyramid's CSRF protection before even reaching schema
# deserialization. So by the time we get to schema deserialization
# there must be a CSRF token and this `missing` value is seemingly
# unnecessary.
#
# However:
#
# 1. The CSRF token may be in an X-CSRF-Token header rather than in a
# POST param.
# 2. Unittests for schemas often don't set a CSRF token and would fail
# if this `missing` value wasn't here.
missing=colander.drop,
)

def validator(self, node, _value):
request = node.bindings["request"]
check_csrf_token(request)


class JSONSchema:
"""
Expand Down
2 changes: 0 additions & 2 deletions h/schemas/forms/accounts/forgot_password.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ class ForgotPasswordSchema(CSRFSchema):
)

def validator(self, node, value):
super().validator(node, value)

request = node.bindings["request"]
email = value.get("email")
user = models.User.get_by_email(request.db, email, request.default_authority)
Expand Down
2 changes: 0 additions & 2 deletions h/schemas/forms/accounts/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class LoginSchema(CSRFSchema):
)

def validator(self, node, value):
super().validator(node, value)

request = node.bindings["request"]
username = value.get("username")
password = value.get("password")
Expand Down
1 change: 0 additions & 1 deletion h/schemas/forms/admin/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,5 +219,4 @@ def __init__(self, *args):
)

def validator(self, node, value):
super().validator(node, value)
username_validator(node, value)
1 change: 1 addition & 0 deletions h/views/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def post(self):
request_param="response_mode=web_message",
is_authenticated=True,
renderer="h:templates/oauth/authorize_web_message.html.jinja2",
require_csrf=False,
)
def post_web_message(self):
"""
Expand Down
1 change: 1 addition & 0 deletions h/views/api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def add_api_view( # noqa: PLR0913
`route_name` must be specified.
:param dict **settings: Arguments to pass on to ``config.add_view``
"""
settings.setdefault("require_csrf", False)
settings.setdefault("renderer", "json")
settings.setdefault("decorator", (cors_policy, version_media_type_header(subtype)))

Expand Down
5 changes: 1 addition & 4 deletions tests/functional/h/views/admin/permissions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ def test_accessible_by_staff(self, app, url, accessible):

assert res.status_code == 200 if accessible else 404

GROUP_PAGES = (
("POST", "/admin/groups/delete/{pubid}", 302),
("GET", "/admin/groups/{pubid}", 200),
)
GROUP_PAGES = (("GET", "/admin/groups/{pubid}", 200),)

@pytest.mark.usefixtures("with_logged_in_admin")
@pytest.mark.parametrize("method,url_template,success_code", GROUP_PAGES)
Expand Down
17 changes: 1 addition & 16 deletions tests/unit/h/accounts/schemas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import colander
import pytest
from pyramid.exceptions import BadCSRFToken

from h.accounts import schemas
from h.services.user_password import UserPasswordService
Expand Down Expand Up @@ -156,9 +155,7 @@ def test_it_validates_with_valid_payload(

result = schema.deserialize(valid_params)

assert result == dict(
valid_params, privacy_accepted=True, comms_opt_in=None, csrf_token=None
)
assert result == dict(valid_params, privacy_accepted=True, comms_opt_in=None)

@pytest.fixture
def valid_params(self):
Expand Down Expand Up @@ -194,18 +191,6 @@ def test_it_is_valid_if_email_same_as_users_existing_email(

schema.deserialize({"email": user.email, "password": "flibble"})

def test_it_is_invalid_if_csrf_token_missing(self, pyramid_request, schema):
del pyramid_request.headers["X-CSRF-Token"]

with pytest.raises(BadCSRFToken):
schema.deserialize({"email": "foo@bar.com", "password": "flibble"})

def test_it_is_invalid_if_csrf_token_wrong(self, pyramid_request, schema):
pyramid_request.headers["X-CSRF-Token"] = "WRONG"

with pytest.raises(BadCSRFToken):
schema.deserialize({"email": "foo@bar.com", "password": "flibble"})

def test_it_is_invalid_if_password_wrong(self, schema, user_password_service):
user_password_service.check_password.return_value = False

Expand Down
24 changes: 0 additions & 24 deletions tests/unit/h/schemas/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import colander
import pytest
from pyramid import csrf
from pyramid.exceptions import BadCSRFToken

from h.schemas import ValidationError
from h.schemas.base import CSRFSchema, JSONSchema, enum_type
Expand All @@ -25,28 +23,6 @@ class ExampleJSONSchema(JSONSchema):
}


class TestCSRFSchema:
def test_raises_badcsrf_with_bad_csrf(self, pyramid_request):
schema = ExampleCSRFSchema().bind(request=pyramid_request)

with pytest.raises(BadCSRFToken):
schema.deserialize({})

def test_ok_with_good_csrf(self, pyramid_request):
csrf_token = csrf.get_csrf_token(pyramid_request)
pyramid_request.POST["csrf_token"] = csrf_token
schema = ExampleCSRFSchema().bind(request=pyramid_request)

# Does not raise
schema.deserialize({})

def test_ok_with_good_csrf_from_header(self, pyramid_csrf_request):
schema = ExampleCSRFSchema().bind(request=pyramid_csrf_request)

# Does not raise
schema.deserialize({})


class TestJSONSchema:
def test_it_raises_for_unsupported_schema_versions(self):
class BadSchema(JSONSchema):
Expand Down
7 changes: 0 additions & 7 deletions tests/unit/h/schemas/forms/accounts/login_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import colander
import pytest
from pyramid.exceptions import BadCSRFToken

from h.schemas.forms.accounts import LoginSchema
from h.services.user import UserNotActivated
Expand Down Expand Up @@ -46,12 +45,6 @@ def test_it_returns_user_when_valid(

assert result["user"] is user

def test_invalid_with_bad_csrf(self, pyramid_request):
schema = LoginSchema().bind(request=pyramid_request)

with pytest.raises(BadCSRFToken):
schema.deserialize({"username": "jeannie", "password": "cake"})

def test_invalid_with_inactive_user(self, pyramid_csrf_request, user_service):
schema = LoginSchema().bind(request=pyramid_csrf_request)
user_service.fetch_for_login.side_effect = UserNotActivated()
Expand Down
13 changes: 0 additions & 13 deletions tests/unit/h/schemas/forms/admin/group_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import colander
import pytest
from pyramid.exceptions import BadCSRFToken

from h.models.group import (
GROUP_DESCRIPTION_MAX_LENGTH,
Expand All @@ -18,18 +17,6 @@ class TestAdminGroupSchema:
def test_it_allows_with_valid_data(self, group_data, bound_schema):
bound_schema.deserialize(group_data)

def test_it_raises_if_csrf_token_missing(self, group_data, bound_schema):
del bound_schema.bindings["request"].headers["X-CSRF-Token"]

with pytest.raises(BadCSRFToken):
bound_schema.deserialize(group_data)

def test_it_raises_if_csrf_token_wrong(self, group_data, bound_schema):
bound_schema.bindings["request"].headers["X-CSRF-Token"] = "foobar"

with pytest.raises(BadCSRFToken):
bound_schema.deserialize(group_data)

def test_it_raises_if_name_too_short(self, group_data, bound_schema):
too_short_name = "a" * (GROUP_NAME_MIN_LENGTH - 1)
group_data["name"] = too_short_name
Expand Down

0 comments on commit 78d79ed

Please sign in to comment.