Source code for watson.form.decorators

# -*- coding: utf-8 -*-
from uuid import uuid4
import hashlib
from watson.common.datastructures import MultiDict
from watson.form import fields
from watson import validators


[docs]def has_csrf(cls): """Adds csrf protection to the form. Adds a new field named 'csrf_token' to the form and overrides the set data method to retrieve the correct token. If a form is csrf protected then a session object must be passed to the __init__ method so that a token can be created (if not already). Example: .. code-block:: python @has_csrf class MyForm(Form): username = fields.Text(required=True) password = fields.Password(required=True) """ class CsrfProtectedForm(cls): csrf_token = fields.Csrf() def __init__(self, name, method='post', action=None, detect_multipart=True, session=None, **kwargs): """Inititalize the form and set some default attributes. Args: name (string): the name of the form method (string): the http method to use action (string): the url to submit the form to detect_multipart (boolean): automatically set multipart/form-data session (watson.http.session.StorageMixin): the session in which to store the token """ if not session: raise ValueError( 'A session must be assigned to the form for validation.') super( CsrfProtectedForm, self).__init__( name, method, action, detect_multipart, **kwargs) token_name = '{0}_csrf_token'.format(self.name) if token_name not in session or not session.get(token_name): token = '{0}{1}{2}'.format(token_name, uuid4().hex, session.id) actual = hashlib.sha256(token.encode('utf-8')).hexdigest() session[token_name] = actual for validator in self.fields['csrf_token'].validators: if isinstance(validator, validators.Csrf): validator.token = session[token_name] self.csrf_token = session[token_name] @property def data(self): # required for the setter override return super(CsrfProtectedForm, self).data @data.setter def data(self, data): """Sets the data for the form. Override the existing set data method and inject the csrf token into the form. """ token_name = '{0}_csrf_token'.format(self.name) if hasattr(data, 'post'): raw_data = MultiDict() raw_data.update(data.files.items()) for key, value in data.post.items(): if key.endswith('_csrf_token'): raw_data['csrf_token'] = value else: raw_data[key] = value else: if token_name in data: data['csrf_token'] = data[token_name] del data[token_name] raw_data = data self._set_data_on_fields(raw_data) def close(self): """Render the end tag of the form. Automatically renders the csrf field within the form. """ return ( '{0}{1}'.format( self.fields['csrf_token'], super( CsrfProtectedForm, self).close())) return CsrfProtectedForm