认证鉴权
前言
以个人目前的工作经验来看,web应用底层架构中最折腾的就是认证鉴权部分,动不动就要对接啥啥啥玩意,出了安全篓子的时候第一个被挂出来批斗。如果公司内部有安全团队,隔三差五还得配合抗压。技术方面还要熟悉各种加密算法、认证方案、鉴权方案……
认证鉴权部分一般是中间件来做的,之前介绍中间件的时候跳过了,这里单独拎出来说,但也不会说太多,就简单介绍下基本的认证鉴权方案。
BasicAuth
Gin框架内置了gin.BasicAuth()
中间件来处理BasicAuth认证。
服务端示例
package main
import (
"net/http"
"github.com/gin-gonic/gin"
)
func main() {
r := gin.Default()
// 在中间件中定义用户名和密码
authorized := r.Group("/admin", gin.BasicAuth(gin.Accounts{
"foo": "bar",
"austin": "1234",
"lena": "hello2",
"manu": "4321",
}))
authorized.GET("/secrets", func(c *gin.Context) {
// get user, it was set by the BasicAuth middleware
user := c.MustGet(gin.AuthUserKey).(string)
c.JSON(http.StatusOK, gin.H{
"message": "Hello " + user,
})
})
r.Run(":8000")
}
如果不带用户名密码请求/admin/secrets
时将会收到401的响应码
$ curl http://127.0.0.1:8000/admin/secrets -v
* Trying 127.0.0.1:8000...
* Connected to 127.0.0.1 (127.0.0.1) port 8000 (#0)
> GET /admin/secrets HTTP/1.1
> Host: 127.0.0.1:8000
> User-Agent: curl/7.88.1
> Accept: */*
>
< HTTP/1.1 401 Unauthorized
< Www-Authenticate: Basic realm="Authorization Required"
< Date: Mon, 25 Nov 2024 16:24:10 GMT
< Content-Length: 0
<
* Connection #0 to host 127.0.0.1 left intact
使用用户名密码的请求
$ curl -u "foo:bar" http://127.0.0.1:8000/admin/secrets -v
* Trying 127.0.0.1:8000...
* Connected to 127.0.0.1 (127.0.0.1) port 8000 (#0)
* Server auth using Basic with user 'foo'
> GET /admin/secrets HTTP/1.1
> Host: 127.0.0.1:8000
> Authorization: Basic Zm9vOmJhcg==
> User-Agent: curl/7.88.1
> Accept: */*
>
< HTTP/1.1 200 OK
< Content-Type: application/json; charset=utf-8
< Date: Mon, 25 Nov 2024 16:26:38 GMT
< Content-Length: 23
<
* Connection #0 to host 127.0.0.1 left intact
{"message":"Hello foo"}
如果需要更复杂的认证逻辑(例如从数据库中动态验证用户名和密码),可以自定义中间件
func BasicAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
username, password, hasAuth := c.Request.BasicAuth()
if !hasAuth || !validateUser(username, password) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
c.Set(gin.AuthUserKey, username)
c.Next()
}
}
// 验证用户名和密码
func validateUser(username, password string) bool {
// 示例:检查用户名和密码是否匹配
return username == "admin" && password == "password123"
}
func main() {
r := gin.Default()
// 自定义 Basic Auth 中间件
authorized := r.Group("/", BasicAuthMiddleware())
authorized.GET("/secure", func(c *gin.Context) {
user := c.MustGet(gin.AuthUserKey).(string)
c.JSON(http.StatusOK, gin.H{
"message": "Welcome " + user,
})
})
r.Run(":8080")
}
API Key
API Key指的是客户端通过一个唯一密钥来标识和验证自己,一般根据请求头中的某个参数鉴权。一旦泄露就会造成安全风险,一般只在内部服务使用。
package main
import (
"net/http"
"github.com/gin-gonic/gin"
)
func APIKeyAuthMiddleware(requiredKey string) gin.HandlerFunc {
return func(c *gin.Context) {
// 从 HTTP 请求头中提取 API Key
apiKey := c.GetHeader("X-API-KEY")
if apiKey == "" || apiKey != requiredKey {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
c.Next()
}
}
func main() {
r := gin.Default()
var validAPIKey string = "qwer1234zxcv"
authorized := r.Group("/admin", APIKeyAuthMiddleware(validAPIKey))
authorized.GET("/secrets", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"message": "Hello World",
})
})
r.Run(":8000")
}
请求测试
curl -H 'X-API-KEY: qwer1234zxcv' http://127.0.0.1:8000/admin/secrets -v
JWT
上面两种方案相当于每个请求都进行了登录操作,极大增加了登录信息的泄露风险。目前主流做法是,在第一次登录之后产生存在有效期的Token,并将它存储在客户端某个地方,比如浏览器的cookie或localstorage。之后的请求都携带这个Token,在请求到达服务端后,由服务端用这个Token对请求进行认证。因为JWT是无状态的,服务端可以不存储jwt的token。
JWT,全称 JSON Web Token,是一种开放标准(RFC 7519),用于安全地在双方之间传递信息。尤其适用于身份验证和授权场景。JWT 的设计允许信息在各方之间安全地、 compactly(紧凑地)传输,因为其自身包含了所有需要的认证信息,从而减少了需要查询数据库或会话存储的需求。
JWT主要由三部分组成,通过 .
连接:
- Header(头部):描述JWT的元数据,通常包括类型(通常是
JWT
)和使用的签名算法(如HS256
、RS256
等)。 - Payload(载荷):包含声明(claims),即用户的相关信息。这些信息可以是公开的,也可以是私有的,但应避免放入敏感信息,因为该部分可以被解码查看。载荷中的声明可以验证,但不加密。
- Signature(签名):用于验证JWT的完整性和来源。它是通过将Header和Payload分别进行Base64编码后,再与一个秘钥(secret)一起通过指定的算法(如HMAC SHA256)计算得出的。
JWT认证流程
基本jwt认证需要实现以下功能
- API:用户登录、需要认证请求的API
- 数据库:存储用户账户信息
- 生成jwt token的方法
- 中间件:通过中间件校验请求
除了基础功能,一般还需要用户登出功能,让token立即失效。最简单的做法是客户端删掉token,但如果token提前保存了,这个token在失效前还是可以继续用的。还有种做法是将登出的token拉进黑名单,鉴权的时候先判断token在不在黑名单里面。单个实例的话还能处理,多实例的话就要考虑黑名单数据同步的问题了,可能就需要引入redis这样的键值对数据库。既然引入了redis,而且每次都要先查一遍redis,还不如直接存token,每次鉴权先检查token在不在redis,这样jwt就变成有状态了,跟session没什么区别。
如果考虑jwt token泄露导致的安全风险,首先token的过期时间不能太久。但如果token的生效时间太短,用户可能需要频繁重新登录,用户体验会比较差。有种做法是双Token机制,用户登录的时候,服务端返回一个长期的refresh token和一个短期的access token。refresh token的过期时间更长,比如1 天,而access token的过期时间可能只有15分钟。客户端每次请求后端资源的时候携带的是access token,如果access token失效,则用refresh token去获取一个新的access token,然后再用新的access token继续请求后端。refresh token只用来获取新的access token,而不涉及请求具体资源。后端可以将refresh token放到redis等键值对数据库,当用户登出时,后端从redis中清除refresh token,这样客户端就获取不到新的access token。这种双token的设计,一定程度上降低了access token泄露的风险。如果refresh token也泄露了,后端也可以从redis中直接清除refresh token。双token没法保证没有安全问题,只能说是在性能、体验、安全各方面都有一定保障。
以下示例代码仅简单演示用户登录获取jwt token,然后用jwt token访问资源的过程。描述虽然就两句话,但涉及数据库连接、jwt token生成和验证、中间件和用户登录功能。
3.0 代码文件目录
这里代码文件结构参考了一些其他项目,然后又按照个人习惯来组织的。go没有固定的文件结构需要去遵守,按个人喜好即可。
├── conf
│ └── config.yaml
├── go.mod
├── go.sum
├── internal
│ ├── dao
│ │ └── user.go
│ ├── handlers
│ │ ├── common.go
│ │ └── users
│ │ └── users.go
│ └── router
│ └── router.go
├── LICENSE
├── main.go
├── models
│ ├── models.go
│ └── users.go
├── pkg
│ ├── config
│ │ └── config.go
│ ├── db
│ │ ├── rdb.go
│ │ └── redis.go
│ ├── log
│ │ └── log.go
│ ├── middlewares
│ │ └── mids.go
│ ├── token
│ │ ├── token.go
│ │ └── token_test.go
│ └── utils
│ ├── utils.go
│ └── utils_test.go
└── README.md
3.1 准备
- 准备数据库。后续代码用了gorm,理论上来说,sqlite、postgres或mysql都行,这里用的postgres,安装方法可以在网上找找步骤,最简单的方式就是用docker了。用包管理器或源码编译的方式也行。
- 创建一个新的应用目录。基于之前代码修修改改也行
mkdir gin-hello
cd gin-hello
go mod init gin-hello
- 编辑配置文件。配置文件放在
conf
目录
log:
logdir: logs
retention: 7 # 单位: 天
console: true
store:
rdb:
dbtype: postgres # mysql | sqlite | postgres
host: 192.168.0.201
port: 5432
username: app
password: app
dbname: app
dbfile: app.sqlite3 # used for sqlite
redis:
enabled: true
addr: 192.168.0.201:6379
username: ""
password: ""
dbnum: 0
jwt:
secrets: "qwertyuiop"
expire_seconds: 7200
3.2 编写基础功能模块
以下写了:配置模块,用于读取配置文件;日志模块;数据库连接模块;token相关模块;中间件模块;工具模块。
3.2.1 配置模块
按个人习惯,代码文件路径为pkg/config/config.go
。这里用viper
来读取配置文件,用其它库来读取也是可以的。有的项目会单独写个Config
结构体,然后为这个结构体实现各种方法,用的时候实例化一个实例即可。
package config
import (
"log/slog"
"os"
"os/signal"
"syscall"
"github.com/spf13/viper"
)
var v *viper.Viper
func init() {
v = viper.New()
// v.AddConfigPath("conf") // 配置文件的目录
v.AddConfigPath("/home/rainux/workspace/go-learn/gin-hello/conf") // 配置文件的目录
v.SetConfigType("yaml") // 配置文件的类型
v.SetConfigName("config")
err := v.ReadInConfig()
if err != nil {
slog.Error("Error reading config file", "error", err)
panic(err)
}
reloadWithSighup()
}
// 监听SIGHUP信号,用于热加载
func reloadWithSighup() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGHUP)
go func() {
for range sigs {
slog.Info("SIGHUB signal received, reloading config...")
if err := v.ReadInConfig(); err != nil {
slog.Error("Error reloading config file", "error", err)
}
slog.Info("Reload config successfully")
}
}()
}
type webaddr struct {
Host string
Port int
}
func GetWebAddr() webaddr {
host := v.GetString("web.host")
port := v.GetInt("web.port")
return webaddr{Host: host, Port: port}
}
func GetGinMode() string {
slog.Info("gin mode: " + v.GetString("web.mode"))
return v.GetString("web.mode")
}
type logcfg struct {
LogDir string
Retention int
Console bool
}
func GetLogCfg() logcfg {
logDir := v.GetString("log.logdir")
retention := v.GetInt("log.retention")
console := v.GetBool("log.console")
return logcfg{LogDir: logDir, Retention: retention, Console: console}
}
type rdbcfg struct {
Dbtype string
Host string
Port int
Username string
Password string
Dbname string
Dbfile string
}
func GetRdbCfg() rdbcfg {
dbtype := v.GetString("store.rdb.dbtype")
host := v.GetString("store.rdb.host")
port := v.GetInt("store.rdb.port")
username := v.GetString("store.rdb.username")
password := v.GetString("store.rdb.password")
dbname := v.GetString("store.rdb.dbname")
dbfile := v.GetString("store.rdb.dbfile")
return rdbcfg{Dbtype: dbtype, Host: host, Port: port, Username: username, Password: password, Dbname: dbname, Dbfile: dbfile}
}
type rediscfg struct {
Enabled bool
Addr string
Username string
Password string
Dbnum int
}
func GetRedisCfg() rediscfg {
return rediscfg{
Enabled: v.GetBool("store.redis.enabled"),
Addr: v.GetString("store.redis.addr"),
Username: v.GetString("store.redis.username"),
Password: v.GetString("store.redis.password"),
Dbnum: v.GetInt("store.redis.dbnum"),
}
}
type jwtcfg struct {
Secrets string
Expire int
}
func GetJwtCfg() jwtcfg {
return jwtcfg{
Secrets: v.GetString("jwt.secrets"),
Expire: v.GetInt("jwt.expire_seconds"),
}
}
3.2.2 数据库连接模块
如果不是追求特别高性能,建议还是用orm,一方面是方便迁移后端数据库类型,另一方面避免sql注入问题。这里用的是gorm,代码文件路径为pkg/db/rdb.go
。下面代码支持使用postgres
、mysql
和sqlite
,便于切换各种数据库。
package db
import (
"fmt"
"gin-hello/pkg/config"
"strings"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
var (
db *gorm.DB
err error
)
func InitDB() error {
cfg := config.GetRdbCfg()
if cfg.Dbtype != "sqlite" {
if cfg.Host == "" || cfg.Dbname == "" || cfg.Username == "" || cfg.Password == "" {
return fmt.Errorf("invalid database configuration")
}
}
switch strings.ToLower(cfg.Dbtype) {
case "mysql":
dsn := getDSNMysql(cfg.Host, cfg.Port, cfg.Username, cfg.Password, cfg.Dbname)
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
return fmt.Errorf("db mysql connect error %s", err)
}
case "postgres":
dsn := getDSNPG(cfg.Host, cfg.Port, cfg.Username, cfg.Password, cfg.Dbname)
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return fmt.Errorf("db postgres connect error %s", err)
}
case "sqlite":
db, err = gorm.Open(sqlite.Open(cfg.Dbfile), &gorm.Config{})
if err != nil {
return fmt.Errorf("db sqlite connect error %s", err)
}
default:
return fmt.Errorf("unsupport dbtype %s", cfg.Dbtype)
}
sqlDB, _ := db.DB()
sqlDB.SetMaxIdleConns(10)
sqlDB.SetMaxOpenConns(20)
return nil
}
// GetDB returns the database instance
func GetDB() *gorm.DB {
return db
}
func getDSNMysql(host string, port int, user string, password string, dbname string) string {
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", user, password, host, port, dbname)
}
func getDSNPG(host string, port int, user string, password string, dbname string) string {
return fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable TimeZone=Asia/Shanghai", host, user, password, dbname, port)
}
3.2.3 日志模块
这里用的是slog
,go 1.22版本后自带的高性能日志库。代码文件路径:pkg/log/log.go
package log
import (
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"time"
"gopkg.in/natefinch/lumberjack.v2"
"gin-hello/pkg/config"
)
var (
logger *slog.Logger
logFile string = "app.log"
logDir string
logRetention int
)
func init() {
logcfg := config.GetLogCfg()
logDir = logcfg.LogDir
logRetention = logcfg.Retention
if _, err := os.Stat(logDir); os.IsNotExist(err) { // 创建日志目录
err := os.Mkdir(logDir, os.ModePerm)
if err != nil {
fmt.Printf("Create log directory '%s' failed\n", logDir)
panic(err)
}
}
logOpts := &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelInfo,
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { // 日志时间格式化
if a.Key == slog.TimeKey {
t, ok := a.Value.Any().(time.Time)
if !ok {
return a
}
return slog.String(slog.TimeKey, t.Format("2006-01-02 15:04:05"))
}
return a
},
}
rotateLogger := getRorateLogger()
var multiWriter io.Writer
if logcfg.Console {
multiWriter = io.MultiWriter(rotateLogger)
} else {
multiWriter = io.MultiWriter(os.Stdout, rotateLogger)
}
fileHandler := slog.NewJSONHandler(multiWriter, logOpts)
logger = slog.New(fileHandler)
}
// 配置日志自动切割
func getRorateLogger() *lumberjack.Logger {
logFilePath := filepath.Join(logDir, logFile)
return &lumberjack.Logger{
Filename: logFilePath,
MaxAge: logRetention,
Compress: true,
}
}
func Debug(msg string, args ...any) {
logger.Debug(msg, args...)
}
func Info(msg string, args ...any) {
logger.Info(msg, args...)
}
func Error(msg string, args ...any) {
logger.Error(msg, args...)
}
func Warn(msg string, args ...any) {
logger.Warn(msg, args...)
}
3.2.4 token功能模块
该模块主要用于生成jwt token和验证jwt token。代码文件路径为pkg/token/token.go
package token
import (
"fmt"
"gin-hello/pkg/config"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
)
func GenerateToken(user_id uint) (string, error) {
jwtcfg := config.GetJwtCfg()
claim := jwt.MapClaims{}
claim["uid"] = user_id
claim["iat"] = time.Now().Unix() // issue at
claim["exp"] = time.Now().Add(time.Second * time.Duration(jwtcfg.Expire)).Unix() // expire at
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claim)
return jwtToken.SignedString([]byte(jwtcfg.Secrets))
}
func VerifyToken(c *gin.Context) error {
tokenString := c.GetHeader("X-Token")
_, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
jwtcfg := config.GetJwtCfg()
return []byte(jwtcfg.Secrets), nil
})
if err != nil {
return err
}
return nil
}
3.2.5 中间件模块
这里只实现了一个认证中间件,代码文件路径:pkg/middlewares/mids.go
package middlewares
import (
"gin-hello/pkg/token"
"net/http"
"github.com/gin-gonic/gin"
)
type mids struct{}
func NewMids() *mids {
return &mids{}
}
func (m *mids) JwtAuth() gin.HandlerFunc {
return func(c *gin.Context) {
err := token.VerifyToken(c)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"code": http.StatusUnauthorized,
"message": "unauthorized",
"data": map[string]any{
"error": err.Error(),
},
})
c.Abort()
return
}
c.Next()
}
}
3.2.6 工具类模块
这个模块包括一些杂七杂八的工具,代码文件路径pkg/utils/utils.go
package utils
import (
"math/rand"
"time"
"golang.org/x/crypto/bcrypt"
)
func HashString(s string) (string, error) {
hashedString, err := bcrypt.GenerateFromPassword([]byte(s), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hashedString), nil
}
func VerifyPassword(password, hashedPassword string) error {
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
}
func GenerateShortUrl(length uint) string {
if length < 6 {
length = 6
}
var charset string = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
seed := rand.NewSource(time.Now().UnixNano())
r := rand.New(seed)
result := make([]byte, length)
for i := range result {
result[i] = charset[r.Intn(len(charset))]
}
return string(result)
}
3.3 模型类
模型类用于映射数据表,代码文件路径为models/users.go
package models
import (
"html"
"strings"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type User struct {
gorm.Model
Username string `gorm:"size:255;not null;unique" json:"username"`
Password string `gorm:"size:255;not null;" json:"password"`
}
func (u *User) TableName() string {
return "users"
}
// 使用gorm的hook在新增用户数据前对密码进行hash加密
func (u *User) BeforeSave(tx *gorm.DB) error {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.Password = string(hashedPassword)
u.Username = html.EscapeString(strings.TrimSpace(u.Username))
return nil
}
添加自动创建表的 函数,如果表结构要自定义,可跳过Migrate的函数。代码文件路径:models/models.go
package models
import "gin-hello/pkg/db"
func MigrateTables() error {
rdb := db.GetDB()
err := rdb.AutoMigrate(&User{})
return err
}
func DropMigrateTables() error {
rdb := db.GetDB()
err := rdb.Migrator().DropTable(&User{})
return err
}
3.4 路由类
路由类用于将路由绑定到路由函数,代码文件为internal/router/router.go
package router
import (
"gin-hello/internal/handlers"
"gin-hello/internal/handlers/users"
"gin-hello/pkg/middlewares"
"github.com/gin-gonic/gin"
)
var mids = middlewares.NewMids()
func NewRouter() *gin.Engine {
r := gin.Default()
r.GET("/health", handlers.GetHealth)
apiGroup := r.Group("/api")
loadUserHandler(apiGroup)
return r
}
func loadUserHandler(r *gin.RouterGroup) {
handler := users.NewUserHandler()
v1Group := r.Group("/v1")
{
v1GroupUser := v1Group.Group("/user")
{
v1GroupUser.POST("/register", handler.Register)
v1GroupUser.POST("/login", handler.Login)
v1GroupUser.POST("/logout", handler.Logout)
v1GroupUser.GET("/list", mids.JwtAuth(), handler.ListUsers)
}
}
}
3.5 处理器(视图函数)
在上面的路由类中,处理器函数,类似flask的视图函数,封装在一个结构体中,也可以只是单纯的函数。
先定义一个用于规范响应体格式的结构体,以及一个健康检测的API,代码文件路径:internal/handlers/common.go
package handlers
import (
"net/http"
"github.com/gin-gonic/gin"
)
type jsonResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data map[string]any `json:"data"`
}
func NewJsonResponse(code int, message string, data map[string]any) *jsonResponse {
return &jsonResponse{
Code: code,
Message: message,
Data: data,
}
}
func GetHealth(c *gin.Context) {
c.JSON(http.StatusOK, NewJsonResponse(http.StatusOK, "OK", nil))
}
然后是具体的业务视图函数,代码文件:internal/handlers/users/users.go
package users
import (
"net/http"
"github.com/gin-gonic/gin"
"gin-hello/internal/dao"
"gin-hello/internal/handlers"
"gin-hello/pkg/log"
"gin-hello/pkg/token"
)
type userHandler struct{}
func NewUserHandler() *userHandler {
return &userHandler{}
}
func (u *userHandler) Register(c *gin.Context) {
var reqbody struct {
Username string `json:"username" binding:"required,min=6"`
Password string `json:"password" binding:"required,min=8,max=16"`
}
if err := c.ShouldBindBodyWithJSON(&reqbody); err != nil {
c.JSON(http.StatusBadRequest, handlers.NewJsonResponse(http.StatusBadRequest, "Invalid request body", map[string]any{"error": err.Error()}))
log.Error("Invalid request body", "error", err)
return
}
log.Info("Register user", "username", reqbody.Username, "password", reqbody.Password) // 生产环境中不要在日志中输入用户的敏感信息
userdao := dao.NewUserDao()
uid, err := userdao.RegisterUser(reqbody.Username, reqbody.Password)
if err != nil {
c.JSON(http.StatusBadRequest, handlers.NewJsonResponse(http.StatusBadRequest, "Register failed", map[string]any{"error": err.Error()}))
log.Error("Register failed", "error", err)
return
}
c.JSON(200, handlers.NewJsonResponse(http.StatusOK, "Register Successfully", map[string]any{"uid": uid}))
}
func (u *userHandler) Login(c *gin.Context) {
var reqbody struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
if err := c.ShouldBindBodyWithJSON(&reqbody); err != nil {
log.Error("Invalid request body", "error", err)
c.JSON(http.StatusBadRequest, handlers.NewJsonResponse(http.StatusBadRequest, "Invalid request body", map[string]any{"error": err.Error()}))
return
}
userdao := dao.NewUserDao()
uid, err := userdao.Login(reqbody.Username, reqbody.Password)
if err != nil {
log.Error("Login failed", "error", err)
c.JSON(http.StatusBadRequest, handlers.NewJsonResponse(http.StatusBadRequest, "Login failed", map[string]any{"error": err.Error()}))
return
}
token, err := token.GenerateToken(uid)
if err != nil {
log.Error("Error generating token", "error", err)
c.JSON(http.StatusInternalServerError, handlers.NewJsonResponse(http.StatusInternalServerError, "Error generating token", map[string]any{"error": err.Error()}))
return
}
c.JSON(200, handlers.NewJsonResponse(http.StatusOK, "Login Successfully", map[string]any{
"uid": uid,
"token": token,
}))
}
func (u *userHandler) Logout(c *gin.Context) {
c.JSON(200, handlers.NewJsonResponse(http.StatusOK, "This is logout api", nil))
}
func (u *userHandler) ListUsers(c *gin.Context) {
c.JSON(200, handlers.NewJsonResponse(http.StatusOK, "This is list users api", nil))
}
主要实现了Login和Register方法,Logout和ListUsers方法没有实现具体功能。
3.6 数据访问类
数据访问类用于操作数据库中的数据,代码文件为internal/dao/user.go
package dao
import (
"fmt"
"gin-hello/models"
"gin-hello/pkg/db"
"gin-hello/pkg/log"
"gin-hello/pkg/utils"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
type userDao struct{}
func NewUserDao() *userDao {
return &userDao{}
}
func (u *userDao) rollback(tx *gorm.DB) {
if r := recover(); r != nil {
tx.Rollback()
}
}
func (u *userDao) InitAdmin(enable bool) error {
if !enable {
return nil
}
_, err := u.GetUserByUsername("admin")
if err == nil {
return nil
}
rdb := db.GetDB()
tx := rdb.Begin()
defer u.rollback(tx)
user := models.User{Username: "admin", Password: "admin"}
if err := tx.Create(&user).Error; err != nil {
tx.Rollback()
return err
}
if err := tx.Commit().Error; err != nil {
tx.Rollback()
return err
}
log.Info("Admin user created successfully")
return nil
}
func (u *userDao) GetUserByUsername(username string) (models.User, error) {
rdb := db.GetDB()
var user models.User
if err := rdb.Where("username = ?", username).First(&user).Error; err != nil {
return models.User{}, err
}
return user, nil
}
func (u *userDao) GetUserById(id uint) (models.User, error) {
rdb := db.GetDB()
var user models.User
if err := rdb.Where("id = ?", id).First(&user).Error; err != nil {
return models.User{}, err
}
return user, nil
}
func (u *userDao) RegisterUser(username, password string) (uint, error) {
_, err := u.GetUserByUsername(username)
if err == nil {
return 0, fmt.Errorf("username %s already exists", username)
}
// Using transactions to maintain data consistency
rdb := db.GetDB()
tx := rdb.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
user := models.User{Username: username, Password: password}
if err := tx.Create(&user).Error; err != nil {
tx.Rollback()
return 0, err
}
// commit the transaction
if err := tx.Commit().Error; err != nil {
return 0, err
}
newUser, err := u.GetUserByUsername(username)
return newUser.ID, err
}
func (u *userDao) Login(username, password string) (uint, error) {
user, err := u.GetUserByUsername(username)
if err != nil {
log.Error("User not found", "error", err)
return 0, err
}
if err := utils.VerifyPassword(password, user.Password); err != nil {
log.Error("Invalid password", "error", err)
return 0, err
}
return user.ID, nil
}
func (u *userDao) UpdatePassword(uid uint, newPassword string) error {
user, err := u.GetUserById(uid)
if err != nil {
log.Error("User not found", "error", err)
return err
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
log.Error("Error hashing password", "error", err)
return err
}
user.Password = string(hashedPassword)
rdb := db.GetDB()
tx := rdb.Begin()
defer u.rollback(tx)
if err := tx.Save(&user).Error; err != nil {
log.Error("Error updating password", "error", err)
tx.Rollback()
return err
}
if err := tx.Commit().Error; err != nil {
log.Error("Error committing transaction", "error", err)
return err
}
return nil
}
3.6 main
最后就是main了,主要就是 加载路由,初始化数据,然后监听服务。
package main
import (
"flag"
"fmt"
"gin-hello/internal/dao"
"gin-hello/internal/router"
"gin-hello/models"
"gin-hello/pkg/db"
)
var (
host string
port int
)
func main() {
flag.StringVar(&host, "host", "127.0.0.1", "host")
flag.IntVar(&port, "port", 8000, "port")
flag.Parse()
r := router.NewRouter()
// models.DropMigrateTables()
if err := db.InitDB(); err != nil {
panic(fmt.Errorf("init db failed, err: %v", err))
}
if err := models.MigrateTables(); err != nil {
panic(fmt.Errorf("migrate tables failed, err: %v", err))
}
userdao := dao.NewUserDao()
userdao.InitAdmin(true)
address := fmt.Sprintf("%s:%d", host, port)
if err := r.Run(address); err != nil {
panic(fmt.Errorf("run server failed, err: %v", err))
}
}
3.7 运行测试
- 编译并运行
go mod tidy
go build
./gin-hello
- 测试。这里用python脚本做的测试,也可以用其它http客户端工具。执行如下python脚本没问题的话,基于jwt的服务端开发就算基本完成了。
import asyncio
import httpx
base_url = "http://127.0.0.1:8000"
class UserApi:
def __init__(self):
self.token = ""
async def get_health(self):
async with httpx.AsyncClient(timeout=1) as client:
resp = await client.get(f"{base_url}/health")
print(f"response code: {resp.status_code}, response content: {resp.text}")
async def login(self):
async with httpx.AsyncClient(timeout=1) as client:
headers = {"Content-Type": "application/json"}
data = {
"username": "admin",
"password": "admin",
}
resp = await client.post(f"{base_url}/api/v1/user/login", json=data, headers=headers)
print(f"response code: {resp.status_code}, response content: {resp.text}")
self.token = resp.json()["data"]["token"]
async def user_lsit(self):
async with httpx.AsyncClient(timeout=1) as client:
headers = {"Content-Type": "application/json", "X-Token": f"{self.token}"}
print(headers)
resp = await client.get(f"{base_url}/api/v1/user/list", headers=headers)
print(f"response code: {resp.status_code}, response content: {resp.text}")
async def async_main(self):
await self.get_health()
await self.login()
await self.user_lsit()
if __name__ == "__main__":
user_api = UserApi()
asyncio.run(user_api.async_main())