You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@allura.apache.org by br...@apache.org on 2022/09/14 18:22:19 UTC

[allura] 01/07: [#8461] convert oauth tests to not mock the oauth library, use requests_oauthlib as a helper to build requests instead

This is an automated email from the ASF dual-hosted git repository.

brondsem pushed a commit to branch db/8461
in repository https://gitbox.apache.org/repos/asf/allura.git

commit 50a05541e9cb7a6a652dee209277e1d5e731afdf
Author: Dave Brondsema <db...@slashdotmedia.com>
AuthorDate: Wed Sep 7 12:45:03 2022 -0400

    [#8461] convert oauth tests to not mock the oauth library, use requests_oauthlib as a helper to build requests instead
---
 Allura/allura/tests/decorators.py           |  19 +-
 Allura/allura/tests/functional/test_auth.py | 296 ++++++++++++++--------------
 AlluraTest/alluratest/controller.py         |  18 ++
 3 files changed, 180 insertions(+), 153 deletions(-)

diff --git a/Allura/allura/tests/decorators.py b/Allura/allura/tests/decorators.py
index 6c561c228..34854ad0f 100644
--- a/Allura/allura/tests/decorators.py
+++ b/Allura/allura/tests/decorators.py
@@ -209,18 +209,29 @@ def out_audits(*messages, **kwargs):
 
 
 # not a decorator but use it with LogCapture() context manager
-def assert_logmsg_and_no_warnings_or_errors(logs, msg):
+def assert_logmsg(logs, msg, maxlevel=logging.CRITICAL+1):
     """
+    can also use logs.check() or logs.check_present()
     :param testfixtures.logcapture.LogCapture logs: LogCapture() instance
-    :param str msg: Message to look for
+    :param str msg: Message substring to look for
     """
     found_msg = False
     for r in logs.records:
         if msg in r.getMessage():
             found_msg = True
-        if r.levelno > logging.INFO:
+        if r.levelno > maxlevel:
             raise AssertionError(f'unexpected log {r.levelname} {r.getMessage()}')
-    assert found_msg, 'Did not find {} in logs: {}'.format(msg, '\n'.join([r.getMessage() for r in logs.records]))
+    assert found_msg, \
+        'Did not find "{}" in these logs: {}'.format(msg, '\n'.join([r.getMessage() for r in logs.records]))
+
+
+def assert_logmsg_and_no_warnings_or_errors(logs, msg):
+    """
+    can also use logs.check() or logs.check_present()
+    :param testfixtures.logcapture.LogCapture logs: LogCapture() instance
+    :param str msg: Message substring to look for
+    """
+    return assert_logmsg(logs, msg, maxlevel=logging.INFO)
 
 
 def assert_equivalent_urls(url1, url2):
diff --git a/Allura/allura/tests/functional/test_auth.py b/Allura/allura/tests/functional/test_auth.py
index c6b8fae11..1d15220d5 100644
--- a/Allura/allura/tests/functional/test_auth.py
+++ b/Allura/allura/tests/functional/test_auth.py
@@ -14,17 +14,22 @@
 #       KIND, either express or implied.  See the License for the
 #       specific language governing permissions and limitations
 #       under the License.
+from __future__ import annotations
 
 import calendar
 from base64 import b32encode
 from datetime import datetime, time, timedelta
 from time import time as time_time
 import json
+
 from six.moves.urllib.parse import urlparse, parse_qs
 from six.moves.urllib.parse import urlencode
 
 from bson import ObjectId
 import re
+
+from testfixtures import LogCapture
+
 from ming.orm.ormsession import ThreadLocalORMSession, session
 from tg import config, expose
 from mock import patch, Mock
@@ -40,17 +45,15 @@ from alluratest.tools import (
     assert_false,
 )
 from tg import tmpl_context as c, app_globals as g
-import oauth2
 
 from allura.tests import TestController
 from allura.tests import decorators as td
-from allura.tests.decorators import audits, out_audits
-from alluratest.controller import setup_trove_categories, TestRestApiBase
+from allura.tests.decorators import audits, out_audits, assert_logmsg
+from alluratest.controller import setup_trove_categories, TestRestApiBase, oauth1_webtest
 from allura import model as M
 from allura.lib import plugin
 from allura.lib import helpers as h
 from allura.lib.multifactor import TotpService, RecoveryCodeService
-import six
 
 
 def unentity(s):
@@ -1854,109 +1857,55 @@ class TestOAuth(TestController):
             M.OAuthAccessToken.for_user(M.User.by_username('test-admin')), [])
 
     def test_interactive(self):
-        with mock.patch('allura.controllers.rest.oauth.Server') as Server, \
-                mock.patch('allura.controllers.rest.oauth.Request') as Request:   # these are the oauth2 libs
-            user = M.User.by_username('test-admin')
-            M.OAuthConsumerToken(
-                api_key='api_key',
-                user_id=user._id,
-                description='ctok_desc',
-            )
-            ThreadLocalORMSession.flush_all()
-            Request.from_request.return_value = {
-                'oauth_consumer_key': 'api_key',
-                'oauth_callback': 'http://my.domain.com/callback',
-            }
-            r = self.app.post('/rest/oauth/request_token', params={})
-            rtok = parse_qs(r.text)['oauth_token'][0]
-            r = self.app.post('/rest/oauth/authorize',
-                              params={'oauth_token': rtok})
-            r = r.forms[0].submit('yes')
-            assert r.location.startswith('http://my.domain.com/callback')
-            pin = parse_qs(urlparse(r.location).query)['oauth_verifier'][0]
-            Request.from_request.return_value = {
-                'oauth_consumer_key': 'api_key',
-                'oauth_token': rtok,
-                'oauth_verifier': pin,
-            }
-            r = self.app.get('/rest/oauth/access_token')
-            atok = parse_qs(r.text)
-            assert_equal(len(atok['oauth_token']), 1)
-            assert_equal(len(atok['oauth_token_secret']), 1)
-
-        # now use the tokens & secrets to make a full OAuth request:
-        oauth_secret = atok['oauth_token_secret'][0]
-        oauth_token = atok['oauth_token'][0]
-        consumer = oauth2.Consumer('api_key', oauth_secret)
-        M.OAuthConsumerToken.consumer = consumer
-        access_token = oauth2.Token(oauth_token, oauth_secret)
-        oauth_client = oauth2.Client(consumer, access_token)
-        # use the oauth2 lib, but intercept the request and then send it to self.app.get
-        with mock.patch('oauth2.httplib2.Http.request', name='hl2req') as oa2_req:
-            oauth_client.request('http://localhost/rest/p/test/', 'GET')
-            oa2url = oa2_req.call_args[0][1]
-            oa2url = oa2url.replace('http://localhost', '')
-            # print(oa2url)
-            oa2kwargs = oa2_req.call_args[1]
-        self.app.get(oa2url, headers=oa2kwargs['headers'], status=200)
-        self.app.get(oa2url.replace('oauth_signature=', 'removed='), headers=oa2kwargs['headers'], status=401)
-
-    @mock.patch('allura.controllers.rest.oauth.Server')
-    @mock.patch('allura.controllers.rest.oauth.Request')
-    def test_request_token_valid(self, Request, Server):
-        M.OAuthConsumerToken.consumer = mock.Mock()
-        user = M.User.by_username('test-user')
-        consumer_token = M.OAuthConsumerToken(
-            api_key='api_key',
-            user_id=user._id,
-        )
-        ThreadLocalORMSession.flush_all()
-        req = Request.from_request.return_value = {'oauth_consumer_key': 'api_key'}
-        r = self.app.post('/rest/oauth/request_token', params={'key': 'value'})
-
-        # dict-ify webob.EnvironHeaders
-        call = Request.from_request.call_args_list[0]
-        call[1]['headers'] = dict(call[1]['headers'])
-        # then check equality
-        assert_equal(Request.from_request.call_args_list, [
-            mock.call('POST', 'http://localhost/rest/oauth/request_token',
-                      headers={'Host': 'localhost:80',
-                               'Content-Type': 'application/x-www-form-urlencoded',
-                               'Content-Length': '9'},
-                      parameters={'key': 'value'},
-                      query_string='')
-        ])
-        Server().verify_request.assert_called_once_with(req, consumer_token.consumer, None)
-        request_token = M.OAuthRequestToken.query.get(consumer_token_id=consumer_token._id)
-        assert_is_not_none(request_token)
-        assert_equal(r.text, request_token.to_string())
-
-    @mock.patch('allura.controllers.rest.oauth.Server')
-    @mock.patch('allura.controllers.rest.oauth.Request')
-    def test_request_token_no_consumer_token_matching(self, Request, Server):
-        Request.from_request.return_value = {'oauth_consumer_key': 'api_key'}
-        self.app.post('/rest/oauth/request_token',
-                      params={'key': 'value'}, status=401)
-
-    @mock.patch('allura.controllers.rest.oauth.Server')
-    @mock.patch('allura.controllers.rest.oauth.Request')
-    def test_request_token_no_consumer_token_given(self, Request, Server):
-        Request.from_request.return_value = {}
-        self.app.post('/rest/oauth/request_token', params={'key': 'value'}, status=401)
-
-    @mock.patch('allura.controllers.rest.oauth.Server')
-    @mock.patch('allura.controllers.rest.oauth.Request')
-    def test_request_token_invalid(self, Request, Server):
-        Server().verify_request.side_effect = oauth2.Error('test_request_token_invalid')
-        M.OAuthConsumerToken.consumer = mock.Mock()
-        user = M.User.by_username('test-user')
+        user = M.User.by_username('test-admin')
         M.OAuthConsumerToken(
             api_key='api_key',
+            secret_key='dummy-client-secret',
             user_id=user._id,
+            description='ctok_desc',
         )
         ThreadLocalORMSession.flush_all()
-        Request.from_request.return_value = {'oauth_consumer_key': 'api_key'}
-        self.app.post('/rest/oauth/request_token', params={'key': 'value'}, status=401)
+        oauth_params = dict(
+            client_key='api_key',
+            client_secret='dummy-client-secret',
+            callback_uri='http://my.domain.com/callback',
+        )
+        r = self.app.post(*oauth1_webtest('/rest/oauth/request_token', oauth_params, method='POST'))
+        rtok = parse_qs(r.text)['oauth_token'][0]
+        rsecr = parse_qs(r.text)['oauth_token_secret'][0]
+        assert rtok
+        assert rsecr
+        r = self.app.post('/rest/oauth/authorize',
+                          params={'oauth_token': rtok})
+        r = r.forms[0].submit('yes')
+        assert r.location.startswith('http://my.domain.com/callback')
+        pin = parse_qs(urlparse(r.location).query)['oauth_verifier'][0]
+        assert pin
+
+        oauth_params = dict(
+            client_key='api_key',
+            client_secret='dummy-client-secret',
+            resource_owner_key=rtok,
+            resource_owner_secret=rsecr,
+            verifier=pin,
+        )
+        r = self.app.get(*oauth1_webtest('/rest/oauth/access_token', oauth_params))
+        atok = parse_qs(r.text)
+        assert_equal(len(atok['oauth_token']), 1)
+        assert_equal(len(atok['oauth_token_secret']), 1)
+
+        # now use the tokens & secrets to make a full OAuth request:
+        oauth_token = atok['oauth_token'][0]
+        oauth_secret = atok['oauth_token_secret'][0]
+        oaurl, oaparams, oahdrs = oauth1_webtest('/rest/p/test/', dict(
+            client_key='api_key',
+            client_secret='dummy-client-secret',
+            resource_owner_key=oauth_token,
+            resource_owner_secret=oauth_secret,
+            signature_type='query'
+        ))
+        self.app.get(oaurl, oaparams, oahdrs, status=200)
+        self.app.get(oaurl.replace('oauth_signature=', 'removed='), oaparams, oahdrs, status=401)
 
     def test_authorize_ok(self):
         user = M.User.by_username('test-admin')
@@ -2048,22 +1997,72 @@ class TestOAuth(TestController):
         r = self.app.post('/rest/oauth/do_authorize', params={'yes': '1', 'oauth_token': 'api_key'})
         assert r.location.startswith('http://my.domain.com/callback?myparam=foo&oauth_token=api_key&oauth_verifier=')
 
-    @mock.patch('allura.controllers.rest.oauth.Request')
-    def test_access_token_no_consumer(self, Request):
-        Request.from_request.return_value = {
-            'oauth_consumer_key': 'api_key',
-            'oauth_token': 'api_key',
-            'oauth_verifier': 'good',
-        }
-        self.app.get('/rest/oauth/access_token', status=401)
-
-    @mock.patch('allura.controllers.rest.oauth.Request')
-    def test_access_token_no_request(self, Request):
-        Request.from_request.return_value = {
-            'oauth_consumer_key': 'api_key',
-            'oauth_token': 'api_key',
-            'oauth_verifier': 'good',
-        }
+
+class TestOAuthRequestToken(TestController):
+
+    oauth_params = dict(
+        client_key='api_key',
+        client_secret='dummy-client-secret',
+    )
+
+    def test_request_token_valid(self):
+        user = M.User.by_username('test-user')
+        consumer_token = M.OAuthConsumerToken(
+            api_key='api_key',
+            secret_key='dummy-client-secret',
+            user_id=user._id,
+        )
+        ThreadLocalORMSession.flush_all()
+        r = self.app.post(*oauth1_webtest('/rest/oauth/request_token', self.oauth_params, method='POST'))
+
+        request_token = M.OAuthRequestToken.query.get(consumer_token_id=consumer_token._id)
+        assert_is_not_none(request_token)
+        assert_equal(r.text, request_token.to_string())
+
+    def test_request_token_no_consumer_token_matching(self):
+        with LogCapture() as logs:
+            self.app.post(*oauth1_webtest('/rest/oauth/request_token', self.oauth_params), status=401)
+        assert_logmsg(logs, 'Invalid consumer token')
+
+    def test_request_token_no_consumer_token_given(self):
+        oauth_params = self.oauth_params.copy()
+        oauth_params['signature_type'] = 'query'  # so we can more easily remove a param next
+        url, params, hdrs = oauth1_webtest('/rest/oauth/request_token', oauth_params)
+        url = url.replace('oauth_consumer_key', 'gone')
+        with LogCapture() as logs:
+            self.app.post(url, params, hdrs, status=401)
+        assert_logmsg(logs, 'Invalid consumer token')
+
+    def test_request_token_invalid(self):
+        user = M.User.by_username('test-user')
+        M.OAuthConsumerToken(
+            api_key='api_key',
+            user_id=user._id,
+            secret_key='dummy-client-secret--INVALID',
+        )
+        ThreadLocalORMSession.flush_all()
+        with LogCapture() as logs:
+            self.app.post(*oauth1_webtest('/rest/oauth/request_token', self.oauth_params, method='POST'),
+                          status=401)
+        assert_logmsg(logs, "Invalid signature <class 'oauth2.Error'> Invalid signature.")
+
+
+class TestOAuthAccessToken(TestController):
+
+    oauth_params = dict(
+        client_key='api_key',
+        client_secret='dummy-client-secret',
+        resource_owner_key='api_key',
+        resource_owner_secret='dummy-token-secret',
+        verifier='good',
+    )
+
+    def test_access_token_no_consumer(self):
+        with LogCapture() as logs:
+            self.app.get(*oauth1_webtest('/rest/oauth/access_token', self.oauth_params), status=401)
+        assert_logmsg(logs, 'Invalid consumer token')
+
+    def test_access_token_no_request(self):
         user = M.User.by_username('test-admin')
         M.OAuthConsumerToken(
             api_key='api_key',
@@ -2071,15 +2070,11 @@ class TestOAuth(TestController):
             description='ctok_desc',
         )
         ThreadLocalORMSession.flush_all()
-        self.app.get('/rest/oauth/access_token', status=401)
-
-    @mock.patch('allura.controllers.rest.oauth.Request')
-    def test_access_token_bad_pin(self, Request):
-        Request.from_request.return_value = {
-            'oauth_consumer_key': 'api_key',
-            'oauth_token': 'api_key',
-            'oauth_verifier': 'bad',
-        }
+        with LogCapture() as logs:
+            self.app.get(*oauth1_webtest('/rest/oauth/access_token', self.oauth_params), status=401)
+        assert_logmsg(logs, 'Invalid request token')
+
+    def test_access_token_bad_pin(self):
         user = M.User.by_username('test-admin')
         ctok = M.OAuthConsumerToken(
             api_key='api_key',
@@ -2094,21 +2089,20 @@ class TestOAuth(TestController):
             validation_pin='good',
         )
         ThreadLocalORMSession.flush_all()
-        self.app.get('/rest/oauth/access_token', status=401)
-
-    @mock.patch('allura.controllers.rest.oauth.Server')
-    @mock.patch('allura.controllers.rest.oauth.Request')
-    def test_access_token_bad_sig(self, Request, Server):
-        Request.from_request.return_value = {
-            'oauth_consumer_key': 'api_key',
-            'oauth_token': 'api_key',
-            'oauth_verifier': 'good',
-        }
+        with LogCapture() as logs:
+            oauth_params = self.oauth_params.copy()
+            oauth_params['verifier'] = 'bad'
+            self.app.get(*oauth1_webtest('/rest/oauth/access_token', oauth_params),
+                         status=401)
+        assert_logmsg(logs, 'Invalid verifier')
+
+    def test_access_token_bad_sig(self):
         user = M.User.by_username('test-admin')
         ctok = M.OAuthConsumerToken(
             api_key='api_key',
             user_id=user._id,
             description='ctok_desc',
+            secret_key='dummy-client-secret',
         )
         M.OAuthRequestToken(
             api_key='api_key',
@@ -2116,34 +2110,38 @@ class TestOAuth(TestController):
             callback='http://my.domain.com/callback?myparam=foo',
             user_id=user._id,
             validation_pin='good',
+            secret_key='dummy-token-secret--INVALID',
         )
         ThreadLocalORMSession.flush_all()
-        Server().verify_request.side_effect = oauth2.Error('test_access_token_bad_sig')
-        self.app.get('/rest/oauth/access_token', status=401)
-
-    @mock.patch('allura.controllers.rest.oauth.Server')
-    @mock.patch('allura.controllers.rest.oauth.Request')
-    def test_access_token_ok(self, Request, Server):
-        Request.from_request.return_value = {
-            'oauth_consumer_key': 'api_key',
-            'oauth_token': 'api_key',
-            'oauth_verifier': 'good',
-        }
+        with LogCapture() as logs:
+            self.app.get(*oauth1_webtest('/rest/oauth/access_token', self.oauth_params), status=401)
+        assert_logmsg(logs, "Invalid signature <class 'oauth2.Error'> Invalid signature.")
+
+    def test_access_token_ok(self):
         user = M.User.by_username('test-admin')
         ctok = M.OAuthConsumerToken(
             api_key='api_key',
+            secret_key='dummy-client-secret',
             user_id=user._id,
             description='ctok_desc',
         )
         M.OAuthRequestToken(
             api_key='api_key',
+            secret_key='dummy-token-secret',
             consumer_token_id=ctok._id,
             callback='http://my.domain.com/callback?myparam=foo',
             user_id=user._id,
             validation_pin='good',
         )
         ThreadLocalORMSession.flush_all()
-        r = self.app.get('/rest/oauth/access_token')
+
+        r = self.app.get(*oauth1_webtest('/rest/oauth/access_token', self.oauth_params))
+        atok = parse_qs(r.text)
+        assert_equal(len(atok['oauth_token']), 1)
+        assert_equal(len(atok['oauth_token_secret']), 1)
+
+        oauth_params = dict(self.oauth_params, signature_type='query')
+        r = self.app.get(*oauth1_webtest('/rest/oauth/access_token', oauth_params))
         atok = parse_qs(r.text)
         assert_equal(len(atok['oauth_token']), 1)
         assert_equal(len(atok['oauth_token_secret']), 1)
diff --git a/AlluraTest/alluratest/controller.py b/AlluraTest/alluratest/controller.py
index 4c7c01bc4..e1e73082e 100644
--- a/AlluraTest/alluratest/controller.py
+++ b/AlluraTest/alluratest/controller.py
@@ -16,6 +16,8 @@
 #       under the License.
 
 """Unit and functional test suite for allura."""
+from __future__ import annotations
+
 import os
 import six.moves.urllib.request
 import six.moves.urllib.parse
@@ -37,6 +39,8 @@ import ew
 from ming.orm import ThreadLocalORMSession
 import ming.orm
 import pkg_resources
+import requests
+import requests_oauthlib
 
 from allura import model as M
 from allura.command import CreateTroveCategoriesCommand
@@ -283,3 +287,17 @@ class TestRestApiBase(TestController):
 
     def api_delete(self, path, wrap_args=None, user='test-admin', status=None, **params):
         return self._api_call('DELETE', path, wrap_args, user, status, **params)
+
+
+def oauth1_webtest(url: str, oauth_kwargs: dict, method='GET') -> tuple[str, dict, dict]:
+    oauth1 = requests_oauthlib.OAuth1(**oauth_kwargs)
+    req = requests.Request(method, f'http://localhost{url}').prepare()
+    oauth1(req)
+    return request2webtest(req)
+
+
+def request2webtest(req: requests.PreparedRequest) -> tuple[str, dict, dict]:
+    url = req.url
+    params = {}
+    headers = {k: v.decode() for k,v in req.headers.items()}
+    return url, params, headers