Browse Source

Add JSSDK context method functionality (#828)

* Add JSSDK context method functionality

* 善JSSDK上下文方法,并添加测试文件

* feat: 完善JSSDK上下文方法,保证协程安全,并添加测试文件

* 修改 import 包分组处理

* feat: 修改测试文件中 fmt.Print -> t.Log

* 删除空行
lizhuang 1 năm trước cách đây
mục cha
commit
b639d2235d

+ 15 - 4
credential/default_js_ticket.go

@@ -1,6 +1,7 @@
 package credential
 
 import (
+	context2 "context"
 	"encoding/json"
 	"fmt"
 	"sync"
@@ -42,6 +43,16 @@ type ResTicket struct {
 
 // GetTicket 获取jsapi_ticket
 func (js *DefaultJsTicket) GetTicket(accessToken string) (ticketStr string, err error) {
+	return js.GetTicketContext(context2.Background(), accessToken)
+}
+
+// GetTicketFromServer 从服务器中获取ticket
+func GetTicketFromServer(accessToken string) (ticket ResTicket, err error) {
+	return GetTicketFromServerContext(context2.Background(), accessToken)
+}
+
+// GetTicketContext 获取jsapi_ticket
+func (js *DefaultJsTicket) GetTicketContext(ctx context2.Context, accessToken string) (ticketStr string, err error) {
 	// 先从cache中取
 	jsAPITicketCacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", js.cacheKeyPrefix, js.appID)
 	if val := js.cache.Get(jsAPITicketCacheKey); val != nil {
@@ -57,7 +68,7 @@ func (js *DefaultJsTicket) GetTicket(accessToken string) (ticketStr string, err
 	}
 
 	var ticket ResTicket
-	ticket, err = GetTicketFromServer(accessToken)
+	ticket, err = GetTicketFromServerContext(ctx, accessToken)
 	if err != nil {
 		return
 	}
@@ -67,11 +78,11 @@ func (js *DefaultJsTicket) GetTicket(accessToken string) (ticketStr string, err
 	return
 }
 
-// GetTicketFromServer 从服务器中获取ticket
-func GetTicketFromServer(accessToken string) (ticket ResTicket, err error) {
+// GetTicketFromServerContext 从服务器中获取ticket
+func GetTicketFromServerContext(ctx context2.Context, accessToken string) (ticket ResTicket, err error) {
 	var response []byte
 	url := fmt.Sprintf(getTicketURL, accessToken)
-	response, err = util.HTTPGet(url)
+	response, err = util.HTTPGetContext(ctx, url)
 	if err != nil {
 		return
 	}

+ 22 - 0
credential/default_js_ticket_test.go

@@ -0,0 +1,22 @@
+package credential
+
+import (
+	"context"
+	"fmt"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"gopkg.in/h2non/gock.v1"
+)
+
+// TestGetTicketFromServerContext 测试 GetTicketFromServerContext 函数
+func TestGetTicketFromServerContext(t *testing.T) {
+	defer gock.Off()
+	gock.New(fmt.Sprintf(getTicketURL, "arg-ak")).Reply(200).JSON(&ResTicket{Ticket: "mock-ticket", ExpiresIn: 10})
+
+	ticket, err := GetTicketFromServerContext(context.Background(), "arg-ak")
+	assert.Nil(t, err)
+	assert.Equal(t, int64(0), ticket.ErrCode)
+	assert.Equal(t, "mock-ticket", ticket.Ticket, "they should be equal")
+	assert.Equal(t, int64(10), ticket.ExpiresIn, "they should be equal")
+}

+ 8 - 0
credential/js_ticket.go

@@ -1,7 +1,15 @@
 package credential
 
+import context2 "context"
+
 // JsTicketHandle js ticket获取
 type JsTicketHandle interface {
 	// GetTicket 获取ticket
 	GetTicket(accessToken string) (ticket string, err error)
 }
+
+// JsTicketContextHandle js ticket获取
+type JsTicketContextHandle interface {
+	JsTicketHandle
+	GetTicketContext(ctx context2.Context, accessToken string) (ticket string, err error)
+}

+ 23 - 2
officialaccount/js/js.go

@@ -1,6 +1,7 @@
 package js
 
 import (
+	context2 "context"
 	"fmt"
 
 	"github.com/silenceper/wechat/v2/credential"
@@ -39,20 +40,40 @@ func (js *Js) SetJsTicketHandle(ticketHandle credential.JsTicketHandle) {
 // GetConfig 获取jssdk需要的配置参数
 // uri 为当前网页地址
 func (js *Js) GetConfig(uri string) (config *Config, err error) {
+	return js.GetConfigContext(context2.Background(), uri)
+}
+
+// GetConfigContext  新方法,允许传入上下文,避免协程泄漏
+func (js *Js) GetConfigContext(ctx context2.Context, uri string) (config *Config, err error) {
 	var accessToken string
-	accessToken, err = js.GetAccessToken()
+	// 类型断言,如果断言成功,调用安全的 GetAccessTokenContext 方法
+	if ctxHandle, ok := js.Context.AccessTokenHandle.(credential.AccessTokenContextHandle); ok {
+		accessToken, err = ctxHandle.GetAccessTokenContext(ctx)
+	} else {
+		// 如果没有实现 AccessTokenContextHandle 接口,调用旧的 GetAccessToken 方法
+		accessToken, err = js.Context.GetAccessToken()
+	}
 	if err != nil {
 		return
 	}
+
 	var ticketStr string
-	ticketStr, err = js.GetTicket(accessToken)
+	// 类型断言 jsTicket
+	if ticketCtxHandle, ok := js.JsTicketHandle.(credential.JsTicketContextHandle); ok {
+		ticketStr, err = ticketCtxHandle.GetTicketContext(ctx, accessToken)
+	} else {
+		// 如果没有实现 JsTicketContextHandle 接口,调用旧的 GetTicket 方法
+		ticketStr, err = js.GetTicket(accessToken)
+	}
 	if err != nil {
 		return
 	}
+
 	nonceStr := util.RandomStr(16)
 	timestamp := util.GetCurrTS()
 	str := fmt.Sprintf("jsapi_ticket=%s&noncestr=%s&timestamp=%d&url=%s", ticketStr, nonceStr, timestamp, uri)
 	sigStr := util.Signature(str)
+
 	config = new(Config)
 	config.AppID = js.AppID
 	config.NonceStr = nonceStr

+ 22 - 3
openplatform/officialaccount/js/js.go

@@ -1,6 +1,7 @@
 package js
 
 import (
+	context2 "context"
 	"fmt"
 
 	"github.com/silenceper/wechat/v2/credential"
@@ -32,14 +33,31 @@ func (js *Js) SetJsTicketHandle(ticketHandle credential.JsTicketHandle) {
 // GetConfig 第三方平台 - 获取jssdk需要的配置参数
 // uri 为当前网页地址
 func (js *Js) GetConfig(uri, appid string) (config *officialJs.Config, err error) {
-	config = new(officialJs.Config)
+	return js.GetConfigContext(context2.Background(), uri, appid)
+}
+
+// GetConfigContext 新方法,允许传入上下文,避免协程泄漏
+func (js *Js) GetConfigContext(ctx context2.Context, uri, appid string) (config *officialJs.Config, err error) {
 	var accessToken string
-	accessToken, err = js.GetAccessToken()
+	// 类型断言,如果断言成功,调用安全的 GetAccessTokenContext 方法
+	if ctxHandle, ok := js.Context.AccessTokenHandle.(credential.AccessTokenContextHandle); ok {
+		accessToken, err = ctxHandle.GetAccessTokenContext(ctx)
+	} else {
+		// 如果没有实现 AccessTokenContextHandle 接口,调用旧的 GetAccessToken 方法
+		accessToken, err = js.Context.GetAccessToken()
+	}
 	if err != nil {
 		return
 	}
+
 	var ticketStr string
-	ticketStr, err = js.GetTicket(accessToken)
+	// 类型断言 jsTicket
+	if ticketCtxHandle, ok := js.JsTicketHandle.(credential.JsTicketContextHandle); ok {
+		ticketStr, err = ticketCtxHandle.GetTicketContext(ctx, accessToken)
+	} else {
+		// 如果没有实现 JsTicketContextHandle 接口,调用旧的 GetTicket 方法
+		ticketStr, err = js.GetTicket(accessToken)
+	}
 	if err != nil {
 		return
 	}
@@ -49,6 +67,7 @@ func (js *Js) GetConfig(uri, appid string) (config *officialJs.Config, err error
 	str := fmt.Sprintf("jsapi_ticket=%s&noncestr=%s&timestamp=%d&url=%s", ticketStr, nonceStr, timestamp, uri)
 	sigStr := util.Signature(str)
 
+	config = new(officialJs.Config)
 	config.AppID = appid
 	config.NonceStr = nonceStr
 	config.Timestamp = timestamp

+ 147 - 0
openplatform/officialaccount/js/js_test.go

@@ -0,0 +1,147 @@
+// 验证 js.GetConfigContext 是否能正确传递上下文到 HTTP 请求,确保上下文正确传播,防止在获取 JSSDK 配置时发生协程泄露。
+package js
+
+import (
+	"bytes"
+	context2 "context"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"testing"
+
+	"github.com/silenceper/wechat/v2/cache"
+	"github.com/silenceper/wechat/v2/credential"
+	"github.com/silenceper/wechat/v2/officialaccount/config"
+	"github.com/silenceper/wechat/v2/officialaccount/context"
+	"github.com/silenceper/wechat/v2/util"
+)
+
+// mockAccessTokenHandle 模拟 AccessTokenHandle
+type mockAccessTokenHandle struct{}
+
+func (m *mockAccessTokenHandle) GetAccessToken() (string, error) {
+	return "mock-access-token", nil
+}
+
+func (m *mockAccessTokenHandle) GetAccessTokenContext(_ context2.Context) (string, error) {
+	return "mock-access-token", nil
+}
+
+// contextCheckingRoundTripper 自定义 RoundTripper 用于检查 context
+type contextCheckingRoundTripper struct {
+	originalCtx context2.Context
+	t           *testing.T
+	key         interface{}
+	expectedVal interface{}
+}
+
+func (rt *contextCheckingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+	// 获取请求中的 context
+	reqCtx := req.Context()
+
+	// 打印 context 比较结果
+	rt.t.Logf("比较上下文的内存地址:\n")
+	if reqCtx == rt.originalCtx {
+		rt.t.Logf("上下文具有相同的内存地址。原始上下文: %p, 请求上下文: %p\n", rt.originalCtx, reqCtx)
+	} else {
+		rt.t.Logf("上下文具有不同的内存地址。原始上下文: %p, 请求上下文: %p\n", rt.originalCtx, reqCtx)
+	}
+
+	// 检查 context 中的键值对
+	if rt.key != nil {
+		value := reqCtx.Value(rt.key)
+		rt.t.Logf("检查请求上下文中的键 %v:\n", rt.key)
+		if value != rt.expectedVal {
+			rt.t.Errorf("上下文键 %v 的值不匹配: 预期 %v, 实际 %v\n", rt.key, rt.expectedVal, value)
+		} else {
+			rt.t.Logf("上下文键 %v 的值匹配: 预期 %v, 实际 %v\n", rt.key, rt.expectedVal, value)
+		}
+	}
+
+	// 检查上下文是否已取消
+	select {
+	case <-reqCtx.Done():
+		return nil, reqCtx.Err() // 返回上下文取消错误
+	default:
+		// 返回模拟的 HTTP 响应,包含有效的 JSON
+		responseBody := `{"ticket":"mock-ticket","expires_in":7200}`
+		response := &http.Response{
+			Status:        "200 OK",
+			StatusCode:    http.StatusOK,
+			Proto:         "HTTP/1.1",
+			ProtoMajor:    1,
+			ProtoMinor:    1,
+			Body:          io.NopCloser(bytes.NewReader([]byte(responseBody))),
+			ContentLength: int64(len(responseBody)),
+			Header:        make(http.Header),
+		}
+		response.Header.Set("Content-Type", "application/json")
+		return response, nil
+	}
+}
+
+// contextKey 定义自定义上下文键类型,避免使用内置 string 类型
+type contextKey string
+
+// setupJsInstance 初始化 Js 实例和 HTTP 客户端
+func setupJsInstance(t *testing.T, ctx context2.Context, key, val interface{}) (*Js, func()) {
+	cfg := &config.Config{
+		AppID:     "test-app-id",
+		AppSecret: "test-app-secret",
+		Cache:     cache.NewMemory(),
+	}
+	cacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", credential.CacheKeyOfficialAccountPrefix, cfg.AppID)
+	if err := cfg.Cache.Delete(cacheKey); err != nil {
+		t.Fatalf("清除缓存失败: %v", err)
+	}
+	t.Log("清除 jsapi_ticket 的缓存:", cacheKey)
+
+	ctxHandle := &context.Context{Config: cfg, AccessTokenHandle: &mockAccessTokenHandle{}}
+	jsInstance := NewJs(ctxHandle, cfg.AppID)
+	jsInstance.SetJsTicketHandle(credential.NewDefaultJsTicket(cfg.AppID, credential.CacheKeyOfficialAccountPrefix, cfg.Cache))
+
+	originalClient := util.DefaultHTTPClient
+	util.DefaultHTTPClient = &http.Client{
+		Transport: &contextCheckingRoundTripper{originalCtx: ctx, t: t, key: key, expectedVal: val},
+	}
+	return jsInstance, func() { util.DefaultHTTPClient = originalClient }
+}
+
+// TestGetConfigContext 测试GetConfigContext的上下文传递和取消行为。
+func TestGetConfigContext(t *testing.T) {
+	t.Run("ContextPassing", func(t *testing.T) {
+		ctxKey := contextKey("testKey111") // 使用自定义类型 contextKey
+		ctxValue := "testValue222"
+		ctx := context2.WithValue(context2.Background(), ctxKey, ctxValue)
+		t.Logf("创建的测试上下文: %p, 添加的键值对: %v=%v\n", ctx, ctxKey, ctxValue)
+
+		jsInstance, cleanup := setupJsInstance(t, ctx, ctxKey, ctxValue)
+		defer cleanup()
+		t.Log("调用 GetConfigContext")
+		config2, err := jsInstance.GetConfigContext(ctx, "https://www.baidu.com", "test-app-id")
+		if err != nil {
+			t.Fatalf("GetConfigContext 失败: %v", err)
+		}
+		if config2.AppID != "test-app-id" {
+			t.Errorf("预期 AppID 为 %s,实际为 %s", "test-app-id", config2.AppID)
+		}
+	})
+
+	t.Run("ContextCancellation", func(t *testing.T) {
+		ctx, cancel := context2.WithCancel(context2.Background())
+		defer cancel()
+
+		jsInstance, cleanup := setupJsInstance(t, ctx, nil, nil)
+		defer cleanup()
+
+		cancel()
+		t.Log("调用 GetConfigContext(已取消上下文)")
+		_, err := jsInstance.GetConfigContext(ctx, "https://www.baidu.com", "test-app-id")
+		if err == nil {
+			t.Error("预期上下文取消错误,但 GetConfigContext 未返回错误")
+		} else if !errors.Is(err, context2.Canceled) {
+			t.Errorf("预期错误为 context.Canceled,实际为: %v", err)
+		}
+	})
+}