From 281af2b255237f918502f3cb7e17a52165b64340 Mon Sep 17 00:00:00 2001 From: Mike Shoup Date: Fri, 26 Oct 2018 09:45:16 -0600 Subject: [PATCH] Improve unit tests for auth.go --- auth.go | 2 +- auth_test.go | 141 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+), 1 deletion(-) create mode 100644 auth_test.go diff --git a/auth.go b/auth.go index dcdd167..2856df8 100644 --- a/auth.go +++ b/auth.go @@ -1,6 +1,6 @@ // Copyright 2014 Manu Martinez-Almeida. All rights reserved. // Use of this source code is governed by a MIT style -// license that can be found in the LICENSE file. +// license that can be found at: https://github.com/gin-gonic/gin/blob/master/LICENSE // Modified to remove the WWW-Authenticate header for uses in TempGopher diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..e1d4657 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,141 @@ +// Copyright 2014 Manu Martinez-Almeida. All rights reserved. +// Use of this source code is governed by a MIT style +// license that can be found at: https://github.com/gin-gonic/gin/blob/master/LICENSE +// Original source: https://github.com/gin-gonic/gin/blob/master/auth_test.go + +// Modified to remove the WWW-Authenticate header for uses in TempGopher + +package main + +import ( + "encoding/base64" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestBasicAuth(t *testing.T) { + pairs := processAccounts(gin.Accounts{ + "admin": "password", + "foo": "bar", + "bar": "foo", + }) + + assert.Len(t, pairs, 3) + assert.Contains(t, pairs, authPair{ + user: "bar", + value: "Basic YmFyOmZvbw==", + }) + assert.Contains(t, pairs, authPair{ + user: "foo", + value: "Basic Zm9vOmJhcg==", + }) + assert.Contains(t, pairs, authPair{ + user: "admin", + value: "Basic YWRtaW46cGFzc3dvcmQ=", + }) +} + +func TestBasicAuthFails(t *testing.T) { + assert.Panics(t, func() { processAccounts(nil) }) + assert.Panics(t, func() { + processAccounts(gin.Accounts{ + "": "password", + "foo": "bar", + }) + }) +} + +func TestBasicAuthSearchCredential(t *testing.T) { + pairs := processAccounts(gin.Accounts{ + "admin": "password", + "foo": "bar", + "bar": "foo", + }) + + user, found := pairs.searchCredential(authorizationHeader("admin", "password")) + assert.Equal(t, "admin", user) + assert.True(t, found) + + user, found = pairs.searchCredential(authorizationHeader("foo", "bar")) + assert.Equal(t, "foo", user) + assert.True(t, found) + + user, found = pairs.searchCredential(authorizationHeader("bar", "foo")) + assert.Equal(t, "bar", user) + assert.True(t, found) + + user, found = pairs.searchCredential(authorizationHeader("admins", "password")) + assert.Empty(t, user) + assert.False(t, found) + + user, found = pairs.searchCredential(authorizationHeader("foo", "bar ")) + assert.Empty(t, user) + assert.False(t, found) + + user, found = pairs.searchCredential("") + assert.Empty(t, user) + assert.False(t, found) +} + +func TestBasicAuthAuthorizationHeader(t *testing.T) { + assert.Equal(t, "Basic YWRtaW46cGFzc3dvcmQ=", authorizationHeader("admin", "password")) +} + +func TestBasicAuthSucceed(t *testing.T) { + accounts := gin.Accounts{"admin": "password"} + router := gin.New() + router.Use(BasicAuth(accounts)) + router.GET("/login", func(c *gin.Context) { + c.String(http.StatusOK, c.MustGet(gin.AuthUserKey).(string)) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/login", nil) + req.Header.Set("Authorization", authorizationHeader("admin", "password")) + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "admin", w.Body.String()) +} + +func TestBasicAuth401(t *testing.T) { + called := false + accounts := gin.Accounts{"foo": "bar"} + router := gin.New() + router.Use(BasicAuth(accounts)) + router.GET("/login", func(c *gin.Context) { + called = true + c.String(http.StatusOK, c.MustGet(gin.AuthUserKey).(string)) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/login", nil) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) + router.ServeHTTP(w, req) + + assert.False(t, called) + assert.Equal(t, http.StatusUnauthorized, w.Code) +} + +func TestBasicAuth401WithCustomRealm(t *testing.T) { + called := false + accounts := gin.Accounts{"foo": "bar"} + router := gin.New() + router.Use(BasicAuth(accounts)) + router.GET("/login", func(c *gin.Context) { + called = true + c.String(http.StatusOK, c.MustGet(gin.AuthUserKey).(string)) + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/login", nil) + req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:password"))) + router.ServeHTTP(w, req) + + assert.False(t, called) + assert.Equal(t, http.StatusUnauthorized, w.Code) +}