跳到主要内容

认证鉴权

前言

以个人目前的工作经验来看,web应用底层架构中最折腾的就是认证鉴权部分,动不动就要对接啥啥啥玩意,出了安全篓子的时候第一个被挂出来批斗。如果公司内部有安全团队,隔三差五还得配合抗压。技术方面还要熟悉各种加密算法、认证方案、鉴权方案……

认证鉴权部分一般是中间件来做的,之前介绍中间件的时候跳过了,这里单独拎出来说,但也不会说太多,就简单介绍下基本的认证鉴权方案。

BasicAuth

Gin框架内置了gin.BasicAuth()中间件来处理BasicAuth认证。

服务端示例

package main

import (
"net/http"

"github.com/gin-gonic/gin"
)

func main() {
r := gin.Default()

// 在中间件中定义用户名和密码
authorized := r.Group("/admin", gin.BasicAuth(gin.Accounts{
"foo": "bar",
"austin": "1234",
"lena": "hello2",
"manu": "4321",
}))

authorized.GET("/secrets", func(c *gin.Context) {
// get user, it was set by the BasicAuth middleware
user := c.MustGet(gin.AuthUserKey).(string)
c.JSON(http.StatusOK, gin.H{
"message": "Hello " + user,
})
})

r.Run(":8000")
}

如果不带用户名密码请求/admin/secrets时将会收到401的响应码

$ curl http://127.0.0.1:8000/admin/secrets -v
* Trying 127.0.0.1:8000...
* Connected to 127.0.0.1 (127.0.0.1) port 8000 (#0)
> GET /admin/secrets HTTP/1.1
> Host: 127.0.0.1:8000
> User-Agent: curl/7.88.1
> Accept: */*
>
< HTTP/1.1 401 Unauthorized
< Www-Authenticate: Basic realm="Authorization Required"
< Date: Mon, 25 Nov 2024 16:24:10 GMT
< Content-Length: 0
<
* Connection #0 to host 127.0.0.1 left intact

使用用户名密码的请求

$ curl -u "foo:bar" http://127.0.0.1:8000/admin/secrets -v
* Trying 127.0.0.1:8000...
* Connected to 127.0.0.1 (127.0.0.1) port 8000 (#0)
* Server auth using Basic with user 'foo'
> GET /admin/secrets HTTP/1.1
> Host: 127.0.0.1:8000
> Authorization: Basic Zm9vOmJhcg==
> User-Agent: curl/7.88.1
> Accept: */*
>
< HTTP/1.1 200 OK
< Content-Type: application/json; charset=utf-8
< Date: Mon, 25 Nov 2024 16:26:38 GMT
< Content-Length: 23
<
* Connection #0 to host 127.0.0.1 left intact
{"message":"Hello foo"}

如果需要更复杂的认证逻辑(例如从数据库中动态验证用户名和密码),可以自定义中间件

func BasicAuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
username, password, hasAuth := c.Request.BasicAuth()
if !hasAuth || !validateUser(username, password) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
c.Set(gin.AuthUserKey, username)
c.Next()
}
}

// 验证用户名和密码
func validateUser(username, password string) bool {
// 示例:检查用户名和密码是否匹配
return username == "admin" && password == "password123"
}

func main() {
r := gin.Default()

// 自定义 Basic Auth 中间件
authorized := r.Group("/", BasicAuthMiddleware())

authorized.GET("/secure", func(c *gin.Context) {
user := c.MustGet(gin.AuthUserKey).(string)

c.JSON(http.StatusOK, gin.H{
"message": "Welcome " + user,
})
})

r.Run(":8080")
}

API Key

API Key指的是客户端通过一个唯一密钥来标识和验证自己,一般根据请求头中的某个参数鉴权。一旦泄露就会造成安全风险,一般只在内部服务使用。

package main

import (
"net/http"

"github.com/gin-gonic/gin"
)

func APIKeyAuthMiddleware(requiredKey string) gin.HandlerFunc {
return func(c *gin.Context) {
// 从 HTTP 请求头中提取 API Key
apiKey := c.GetHeader("X-API-KEY")

if apiKey == "" || apiKey != requiredKey {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}

c.Next()
}
}

func main() {
r := gin.Default()

var validAPIKey string = "qwer1234zxcv"

authorized := r.Group("/admin", APIKeyAuthMiddleware(validAPIKey))

authorized.GET("/secrets", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"message": "Hello World",
})
})

r.Run(":8000")
}

请求测试

curl -H 'X-API-KEY: qwer1234zxcv' http://127.0.0.1:8000/admin/secrets -v

JWT

上面两种方案相当于每个请求都进行了登录操作,极大增加了登录信息的泄露风险。目前主流做法是,在第一次登录之后产生存在有效期的Token,并将它存储在客户端某个地方,比如浏览器的cookie或localstorage。之后的请求都携带这个Token,在请求到达服务端后,由服务端用这个Token对请求进行认证。因为JWT是无状态的,服务端可以不存储jwt的token。

JWT,全称 JSON Web Token,是一种开放标准(RFC 7519),用于安全地在双方之间传递信息。尤其适用于身份验证和授权场景。JWT 的设计允许信息在各方之间安全地、 compactly(紧凑地)传输,因为其自身包含了所有需要的认证信息,从而减少了需要查询数据库或会话存储的需求。

JWT主要由三部分组成,通过 . 连接:

  1. Header(头部):描述JWT的元数据,通常包括类型(通常是JWT)和使用的签名算法(如HS256RS256等)。
  2. Payload(载荷):包含声明(claims),即用户的相关信息。这些信息可以是公开的,也可以是私有的,但应避免放入敏感信息,因为该部分可以被解码查看。载荷中的声明可以验证,但不加密。
  3. Signature(签名):用于验证JWT的完整性和来源。它是通过将Header和Payload分别进行Base64编码后,再与一个秘钥(secret)一起通过指定的算法(如HMAC SHA256)计算得出的。

JWT认证流程

基本jwt认证需要实现以下功能

  • API:用户登录、需要认证请求的API
  • 数据库:存储用户账户信息
  • 生成jwt token的方法
  • 中间件:通过中间件校验请求

除了基础功能,一般还需要用户登出功能,让token立即失效。最简单的做法是客户端删掉token,但如果token提前保存了,这个token在失效前还是可以继续用的。还有种做法是将登出的token拉进黑名单,鉴权的时候先判断token在不在黑名单里面。单个实例的话还能处理,多实例的话就要考虑黑名单数据同步的问题了,可能就需要引入redis这样的键值对数据库。既然引入了redis,而且每次都要先查一遍redis,还不如直接存token,每次鉴权先检查token在不在redis,这样jwt就变成有状态了,跟session没什么区别。

如果考虑jwt token泄露导致的安全风险,首先token的过期时间不能太久。但如果token的生效时间太短,用户可能需要频繁重新登录,用户体验会比较差。有种做法是双Token机制,用户登录的时候,服务端返回一个长期的refresh token和一个短期的access token。refresh token的过期时间更长,比如1天,而access token的过期时间可能只有15分钟。客户端每次请求后端资源的时候携带的是access token,如果access token失效,则用refresh token去获取一个新的access token,然后再用新的access token继续请求后端。refresh token只用来获取新的access token,而不涉及请求具体资源。后端可以将refresh token放到redis等键值对数据库,当用户登出时,后端从redis中清除refresh token,这样客户端就获取不到新的access token。这种双token的设计,一定程度上降低了access token泄露的风险。如果refresh token也泄露了,后端也可以从redis中直接清除refresh token。双token没法保证没有安全问题,只能说是在性能、体验、安全各方面都有一定保障。

以下示例代码仅简单演示用户登录获取jwt token,然后用jwt token访问资源的过程。描述虽然就两句话,但涉及数据库连接、jwt token生成和验证、中间件和用户登录功能。

3.0 代码文件目录

这里代码文件结构参考了一些其他项目,然后又按照个人习惯来组织的。go没有固定的文件结构需要去遵守,按个人喜好即可。

├── conf
│   └── config.yaml
├── go.mod
├── go.sum
├── internal
│   ├── dao
│   │   └── user.go
│   ├── handlers
│   │   ├── common.go
│   │   └── users
│   │   └── users.go
│   └── router
│   └── router.go
├── LICENSE
├── main.go
├── models
│   ├── models.go
│   └── users.go
├── pkg
│   ├── config
│   │   └── config.go
│   ├── db
│   │   ├── rdb.go
│   │   └── redis.go
│   ├── log
│   │   └── log.go
│   ├── middlewares
│   │   └── mids.go
│   ├── token
│   │   ├── token.go
│   │   └── token_test.go
│   └── utils
│   ├── utils.go
│   └── utils_test.go
└── README.md

3.1 准备

  1. 准备数据库。后续代码用了gorm,理论上来说,sqlite、postgres或mysql都行,这里用的postgres,安装方法可以在网上找找步骤,最简单的方式就是用docker了。用包管理器或源码编译的方式也行。
  2. 创建一个新的应用目录。基于之前代码修修改改也行
mkdir gin-hello
cd gin-hello
go mod init gin-hello
  1. 编辑配置文件。配置文件放在conf目录
log:
logdir: logs
retention: 7 # 单位: 天
console: true

store:
rdb:
dbtype: postgres # mysql | sqlite | postgres
host: 192.168.0.201
port: 5432
username: app
password: app
dbname: app
dbfile: app.sqlite3 # used for sqlite
redis:
enabled: true
addr: 192.168.0.201:6379
username: ""
password: ""
dbnum: 0

jwt:
secrets: "qwertyuiop"
expire_seconds: 7200

3.2 编写基础功能模块

以下写了:配置模块,用于读取配置文件;日志模块;数据库连接模块;token相关模块;中间件模块;工具模块。

3.2.1 配置模块

按个人习惯,代码文件路径为pkg/config/config.go。这里用viper来读取配置文件,用其它库来读取也是可以的。有的项目会单独写个Config结构体,然后为这个结构体实现各种方法,用的时候实例化一个实例即可。

package config

import (
"log/slog"
"os"
"os/signal"
"syscall"

"github.com/spf13/viper"
)

var v *viper.Viper

func init() {
v = viper.New()
// v.AddConfigPath("conf") // 配置文件的目录
v.AddConfigPath("/home/rainux/workspace/go-learn/gin-hello/conf") // 配置文件的目录
v.SetConfigType("yaml") // 配置文件的类型
v.SetConfigName("config")

err := v.ReadInConfig()
if err != nil {
slog.Error("Error reading config file", "error", err)
panic(err)
}
reloadWithSighup()
}

// 监听SIGHUP信号,用于热加载
func reloadWithSighup() {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGHUP)

go func() {
for range sigs {
slog.Info("SIGHUB signal received, reloading config...")
if err := v.ReadInConfig(); err != nil {
slog.Error("Error reloading config file", "error", err)
}
slog.Info("Reload config successfully")
}
}()
}

type webaddr struct {
Host string
Port int
}

func GetWebAddr() webaddr {
host := v.GetString("web.host")
port := v.GetInt("web.port")
return webaddr{Host: host, Port: port}
}

func GetGinMode() string {
slog.Info("gin mode: " + v.GetString("web.mode"))
return v.GetString("web.mode")
}

type logcfg struct {
LogDir string
Retention int
Console bool
}

func GetLogCfg() logcfg {
logDir := v.GetString("log.logdir")
retention := v.GetInt("log.retention")
console := v.GetBool("log.console")
return logcfg{LogDir: logDir, Retention: retention, Console: console}
}

type rdbcfg struct {
Dbtype string
Host string
Port int
Username string
Password string
Dbname string
Dbfile string
}

func GetRdbCfg() rdbcfg {
dbtype := v.GetString("store.rdb.dbtype")
host := v.GetString("store.rdb.host")
port := v.GetInt("store.rdb.port")
username := v.GetString("store.rdb.username")
password := v.GetString("store.rdb.password")
dbname := v.GetString("store.rdb.dbname")
dbfile := v.GetString("store.rdb.dbfile")
return rdbcfg{Dbtype: dbtype, Host: host, Port: port, Username: username, Password: password, Dbname: dbname, Dbfile: dbfile}
}

type rediscfg struct {
Enabled bool
Addr string
Username string
Password string
Dbnum int
}

func GetRedisCfg() rediscfg {
return rediscfg{
Enabled: v.GetBool("store.redis.enabled"),
Addr: v.GetString("store.redis.addr"),
Username: v.GetString("store.redis.username"),
Password: v.GetString("store.redis.password"),
Dbnum: v.GetInt("store.redis.dbnum"),
}
}

type jwtcfg struct {
Secrets string
Expire int
}

func GetJwtCfg() jwtcfg {
return jwtcfg{
Secrets: v.GetString("jwt.secrets"),
Expire: v.GetInt("jwt.expire_seconds"),
}
}

3.2.2 数据库连接模块

如果不是追求特别高性能,建议还是用orm,一方面是方便迁移后端数据库类型,另一方面避免sql注入问题。这里用的是gorm,代码文件路径为pkg/db/rdb.go。下面代码支持使用postgresmysqlsqlite,便于切换各种数据库。

package db

import (
"fmt"
"gin-hello/pkg/config"
"strings"

"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)

var (
db *gorm.DB
err error
)

func InitDB() error {
cfg := config.GetRdbCfg()
if cfg.Dbtype != "sqlite" {
if cfg.Host == "" || cfg.Dbname == "" || cfg.Username == "" || cfg.Password == "" {
return fmt.Errorf("invalid database configuration")
}
}
switch strings.ToLower(cfg.Dbtype) {
case "mysql":
dsn := getDSNMysql(cfg.Host, cfg.Port, cfg.Username, cfg.Password, cfg.Dbname)
db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{})
if err != nil {
return fmt.Errorf("db mysql connect error %s", err)
}
case "postgres":
dsn := getDSNPG(cfg.Host, cfg.Port, cfg.Username, cfg.Password, cfg.Dbname)
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return fmt.Errorf("db postgres connect error %s", err)
}
case "sqlite":
db, err = gorm.Open(sqlite.Open(cfg.Dbfile), &gorm.Config{})
if err != nil {
return fmt.Errorf("db sqlite connect error %s", err)
}
default:
return fmt.Errorf("unsupport dbtype %s", cfg.Dbtype)
}
sqlDB, _ := db.DB()
sqlDB.SetMaxIdleConns(10)
sqlDB.SetMaxOpenConns(20)
return nil
}

// GetDB returns the database instance
func GetDB() *gorm.DB {
return db
}

func getDSNMysql(host string, port int, user string, password string, dbname string) string {
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", user, password, host, port, dbname)
}

func getDSNPG(host string, port int, user string, password string, dbname string) string {
return fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable TimeZone=Asia/Shanghai", host, user, password, dbname, port)
}

3.2.3 日志模块

这里用的是slog,go 1.22版本后自带的高性能日志库。代码文件路径:pkg/log/log.go

package log

import (
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"time"

"gopkg.in/natefinch/lumberjack.v2"

"gin-hello/pkg/config"
)

var (
logger *slog.Logger
logFile string = "app.log"
logDir string
logRetention int
)

func init() {
logcfg := config.GetLogCfg()
logDir = logcfg.LogDir
logRetention = logcfg.Retention
if _, err := os.Stat(logDir); os.IsNotExist(err) { // 创建日志目录
err := os.Mkdir(logDir, os.ModePerm)
if err != nil {
fmt.Printf("Create log directory '%s' failed\n", logDir)
panic(err)
}
}
logOpts := &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelInfo,
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { // 日志时间格式化
if a.Key == slog.TimeKey {
t, ok := a.Value.Any().(time.Time)
if !ok {
return a
}
return slog.String(slog.TimeKey, t.Format("2006-01-02 15:04:05"))
}
return a
},
}

rotateLogger := getRorateLogger()
var multiWriter io.Writer
if logcfg.Console {
multiWriter = io.MultiWriter(rotateLogger)
} else {
multiWriter = io.MultiWriter(os.Stdout, rotateLogger)
}

fileHandler := slog.NewJSONHandler(multiWriter, logOpts)
logger = slog.New(fileHandler)
}

// 配置日志自动切割
func getRorateLogger() *lumberjack.Logger {
logFilePath := filepath.Join(logDir, logFile)
return &lumberjack.Logger{
Filename: logFilePath,
MaxAge: logRetention,
Compress: true,
}
}

func Debug(msg string, args ...any) {
logger.Debug(msg, args...)
}

func Info(msg string, args ...any) {
logger.Info(msg, args...)
}

func Error(msg string, args ...any) {
logger.Error(msg, args...)
}

func Warn(msg string, args ...any) {
logger.Warn(msg, args...)
}

3.2.4 token功能模块

该模块主要用于生成jwt token和验证jwt token。代码文件路径为pkg/token/token.go

package token

import (
"fmt"
"gin-hello/pkg/config"
"time"

"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
)

func GenerateToken(user_id uint) (string, error) {
jwtcfg := config.GetJwtCfg()
claim := jwt.MapClaims{}
claim["uid"] = user_id
claim["iat"] = time.Now().Unix() // issue at
claim["exp"] = time.Now().Add(time.Second * time.Duration(jwtcfg.Expire)).Unix() // expire at
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claim)
return jwtToken.SignedString([]byte(jwtcfg.Secrets))
}

func VerifyToken(c *gin.Context) error {
tokenString := c.GetHeader("X-Token")
_, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
jwtcfg := config.GetJwtCfg()
return []byte(jwtcfg.Secrets), nil
})

if err != nil {
return err
}
return nil
}

3.2.5 中间件模块

这里只实现了一个认证中间件,代码文件路径:pkg/middlewares/mids.go

package middlewares

import (
"gin-hello/pkg/token"
"net/http"

"github.com/gin-gonic/gin"
)

type mids struct{}

func NewMids() *mids {
return &mids{}
}

func (m *mids) JwtAuth() gin.HandlerFunc {
return func(c *gin.Context) {
err := token.VerifyToken(c)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{
"code": http.StatusUnauthorized,
"message": "unauthorized",
"data": map[string]any{
"error": err.Error(),
},
})
c.Abort()
return
}
c.Next()
}
}

3.2.6 工具类模块

这个模块包括一些杂七杂八的工具,代码文件路径pkg/utils/utils.go

package utils

import (
"math/rand"
"time"

"golang.org/x/crypto/bcrypt"
)

func HashString(s string) (string, error) {
hashedString, err := bcrypt.GenerateFromPassword([]byte(s), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hashedString), nil
}

func VerifyPassword(password, hashedPassword string) error {
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
}

func GenerateShortUrl(length uint) string {
if length < 6 {
length = 6
}
var charset string = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
seed := rand.NewSource(time.Now().UnixNano())
r := rand.New(seed)
result := make([]byte, length)
for i := range result {
result[i] = charset[r.Intn(len(charset))]
}
return string(result)
}

3.3 模型类

模型类用于映射数据表,代码文件路径为models/users.go

package models

import (
"html"
"strings"

"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)

type User struct {
gorm.Model
Username string `gorm:"size:255;not null;unique" json:"username"`
Password string `gorm:"size:255;not null;" json:"password"`
}

func (u *User) TableName() string {
return "users"
}

// 使用gorm的hook在新增用户数据前对密码进行hash加密
func (u *User) BeforeSave(tx *gorm.DB) error {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
if err != nil {
return err
}
u.Password = string(hashedPassword)
u.Username = html.EscapeString(strings.TrimSpace(u.Username))
return nil
}

添加自动创建表的函数,如果表结构要自定义,可跳过Migrate的函数。代码文件路径:models/models.go

package models

import "gin-hello/pkg/db"

func MigrateTables() error {
rdb := db.GetDB()
err := rdb.AutoMigrate(&User{})
return err
}

func DropMigrateTables() error {
rdb := db.GetDB()
err := rdb.Migrator().DropTable(&User{})
return err
}

3.4 路由类

路由类用于将路由绑定到路由函数,代码文件为internal/router/router.go

package router

import (
"gin-hello/internal/handlers"
"gin-hello/internal/handlers/users"
"gin-hello/pkg/middlewares"

"github.com/gin-gonic/gin"
)

var mids = middlewares.NewMids()

func NewRouter() *gin.Engine {
r := gin.Default()
r.GET("/health", handlers.GetHealth)

apiGroup := r.Group("/api")
loadUserHandler(apiGroup)

return r
}

func loadUserHandler(r *gin.RouterGroup) {
handler := users.NewUserHandler()
v1Group := r.Group("/v1")
{
v1GroupUser := v1Group.Group("/user")
{
v1GroupUser.POST("/register", handler.Register)
v1GroupUser.POST("/login", handler.Login)
v1GroupUser.POST("/logout", handler.Logout)
v1GroupUser.GET("/list", mids.JwtAuth(), handler.ListUsers)
}
}
}

3.5 处理器(视图函数)

在上面的路由类中,处理器函数,类似flask的视图函数,封装在一个结构体中,也可以只是单纯的函数。

先定义一个用于规范响应体格式的结构体,以及一个健康检测的API,代码文件路径:internal/handlers/common.go

package handlers

import (
"net/http"

"github.com/gin-gonic/gin"
)

type jsonResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data map[string]any `json:"data"`
}

func NewJsonResponse(code int, message string, data map[string]any) *jsonResponse {
return &jsonResponse{
Code: code,
Message: message,
Data: data,
}
}

func GetHealth(c *gin.Context) {
c.JSON(http.StatusOK, NewJsonResponse(http.StatusOK, "OK", nil))
}

然后是具体的业务视图函数,代码文件:internal/handlers/users/users.go

package users

import (
"net/http"

"github.com/gin-gonic/gin"

"gin-hello/internal/dao"
"gin-hello/internal/handlers"
"gin-hello/pkg/log"
"gin-hello/pkg/token"
)

type userHandler struct{}

func NewUserHandler() *userHandler {
return &userHandler{}
}

func (u *userHandler) Register(c *gin.Context) {
var reqbody struct {
Username string `json:"username" binding:"required,min=6"`
Password string `json:"password" binding:"required,min=8,max=16"`
}
if err := c.ShouldBindBodyWithJSON(&reqbody); err != nil {
c.JSON(http.StatusBadRequest, handlers.NewJsonResponse(http.StatusBadRequest, "Invalid request body", map[string]any{"error": err.Error()}))
log.Error("Invalid request body", "error", err)
return
}
log.Info("Register user", "username", reqbody.Username, "password", reqbody.Password) // 生产环境中不要在日志中输入用户的敏感信息
userdao := dao.NewUserDao()
uid, err := userdao.RegisterUser(reqbody.Username, reqbody.Password)
if err != nil {
c.JSON(http.StatusBadRequest, handlers.NewJsonResponse(http.StatusBadRequest, "Register failed", map[string]any{"error": err.Error()}))
log.Error("Register failed", "error", err)
return
}
c.JSON(200, handlers.NewJsonResponse(http.StatusOK, "Register Successfully", map[string]any{"uid": uid}))
}

func (u *userHandler) Login(c *gin.Context) {
var reqbody struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
if err := c.ShouldBindBodyWithJSON(&reqbody); err != nil {
log.Error("Invalid request body", "error", err)
c.JSON(http.StatusBadRequest, handlers.NewJsonResponse(http.StatusBadRequest, "Invalid request body", map[string]any{"error": err.Error()}))
return
}

userdao := dao.NewUserDao()
uid, err := userdao.Login(reqbody.Username, reqbody.Password)
if err != nil {
log.Error("Login failed", "error", err)
c.JSON(http.StatusBadRequest, handlers.NewJsonResponse(http.StatusBadRequest, "Login failed", map[string]any{"error": err.Error()}))
return
}
token, err := token.GenerateToken(uid)
if err != nil {
log.Error("Error generating token", "error", err)
c.JSON(http.StatusInternalServerError, handlers.NewJsonResponse(http.StatusInternalServerError, "Error generating token", map[string]any{"error": err.Error()}))
return
}

c.JSON(200, handlers.NewJsonResponse(http.StatusOK, "Login Successfully", map[string]any{
"uid": uid,
"token": token,
}))
}

func (u *userHandler) Logout(c *gin.Context) {
c.JSON(200, handlers.NewJsonResponse(http.StatusOK, "This is logout api", nil))
}

func (u *userHandler) ListUsers(c *gin.Context) {
c.JSON(200, handlers.NewJsonResponse(http.StatusOK, "This is list users api", nil))
}

主要实现了Login和Register方法,Logout和ListUsers方法没有实现具体功能。

3.6 数据访问类

数据访问类用于操作数据库中的数据,代码文件为internal/dao/user.go

package dao

import (
"fmt"
"gin-hello/models"
"gin-hello/pkg/db"
"gin-hello/pkg/log"
"gin-hello/pkg/utils"

"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)

type userDao struct{}

func NewUserDao() *userDao {
return &userDao{}
}

func (u *userDao) rollback(tx *gorm.DB) {
if r := recover(); r != nil {
tx.Rollback()
}
}

func (u *userDao) InitAdmin(enable bool) error {
if !enable {
return nil
}
_, err := u.GetUserByUsername("admin")
if err == nil {
return nil
}
rdb := db.GetDB()
tx := rdb.Begin()
defer u.rollback(tx)
user := models.User{Username: "admin", Password: "admin"}
if err := tx.Create(&user).Error; err != nil {
tx.Rollback()
return err
}
if err := tx.Commit().Error; err != nil {
tx.Rollback()
return err
}
log.Info("Admin user created successfully")
return nil
}

func (u *userDao) GetUserByUsername(username string) (models.User, error) {
rdb := db.GetDB()
var user models.User
if err := rdb.Where("username = ?", username).First(&user).Error; err != nil {
return models.User{}, err
}
return user, nil
}

func (u *userDao) GetUserById(id uint) (models.User, error) {
rdb := db.GetDB()
var user models.User
if err := rdb.Where("id = ?", id).First(&user).Error; err != nil {
return models.User{}, err
}
return user, nil
}

func (u *userDao) RegisterUser(username, password string) (uint, error) {
_, err := u.GetUserByUsername(username)
if err == nil {
return 0, fmt.Errorf("username %s already exists", username)
}

// Using transactions to maintain data consistency
rdb := db.GetDB()
tx := rdb.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
user := models.User{Username: username, Password: password}
if err := tx.Create(&user).Error; err != nil {
tx.Rollback()
return 0, err
}
// commit the transaction
if err := tx.Commit().Error; err != nil {
return 0, err
}

newUser, err := u.GetUserByUsername(username)
return newUser.ID, err
}

func (u *userDao) Login(username, password string) (uint, error) {
user, err := u.GetUserByUsername(username)
if err != nil {
log.Error("User not found", "error", err)
return 0, err
}
if err := utils.VerifyPassword(password, user.Password); err != nil {
log.Error("Invalid password", "error", err)
return 0, err
}
return user.ID, nil
}

func (u *userDao) UpdatePassword(uid uint, newPassword string) error {
user, err := u.GetUserById(uid)
if err != nil {
log.Error("User not found", "error", err)
return err
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
log.Error("Error hashing password", "error", err)
return err
}
user.Password = string(hashedPassword)
rdb := db.GetDB()
tx := rdb.Begin()
defer u.rollback(tx)
if err := tx.Save(&user).Error; err != nil {
log.Error("Error updating password", "error", err)
tx.Rollback()
return err
}
if err := tx.Commit().Error; err != nil {
log.Error("Error committing transaction", "error", err)
return err
}
return nil
}

3.6 main

最后就是main了,主要就是加载路由,初始化数据,然后监听服务。

package main

import (
"flag"
"fmt"
"gin-hello/internal/dao"
"gin-hello/internal/router"
"gin-hello/models"
"gin-hello/pkg/db"
)

var (
host string
port int
)

func main() {
flag.StringVar(&host, "host", "127.0.0.1", "host")
flag.IntVar(&port, "port", 8000, "port")
flag.Parse()
r := router.NewRouter()
// models.DropMigrateTables()
if err := db.InitDB(); err != nil {
panic(fmt.Errorf("init db failed, err: %v", err))
}
if err := models.MigrateTables(); err != nil {
panic(fmt.Errorf("migrate tables failed, err: %v", err))
}
userdao := dao.NewUserDao()
userdao.InitAdmin(true)
address := fmt.Sprintf("%s:%d", host, port)
if err := r.Run(address); err != nil {
panic(fmt.Errorf("run server failed, err: %v", err))
}
}

3.7 运行测试

  1. 编译并运行
go mod tidy
go build
./gin-hello
  1. 测试。这里用python脚本做的测试,也可以用其它http客户端工具。执行如下python脚本没问题的话,基于jwt的服务端开发就算基本完成了。
import asyncio
import httpx

base_url = "http://127.0.0.1:8000"

class UserApi:
def __init__(self):
self.token = ""

async def get_health(self):
async with httpx.AsyncClient(timeout=1) as client:
resp = await client.get(f"{base_url}/health")
print(f"response code: {resp.status_code}, response content: {resp.text}")

async def login(self):
async with httpx.AsyncClient(timeout=1) as client:
headers = {"Content-Type": "application/json"}
data = {
"username": "admin",
"password": "admin",
}
resp = await client.post(f"{base_url}/api/v1/user/login", json=data, headers=headers)
print(f"response code: {resp.status_code}, response content: {resp.text}")
self.token = resp.json()["data"]["token"]

async def user_lsit(self):
async with httpx.AsyncClient(timeout=1) as client:
headers = {"Content-Type": "application/json", "X-Token": f"{self.token}"}
print(headers)
resp = await client.get(f"{base_url}/api/v1/user/list", headers=headers)
print(f"response code: {resp.status_code}, response content: {resp.text}")

async def async_main(self):
await self.get_health()
await self.login()
await self.user_lsit()

if __name__ == "__main__":
user_api = UserApi()
asyncio.run(user_api.async_main())