js_test.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. // 验证 js.GetConfigContext 是否能正确传递上下文到 HTTP 请求,确保上下文正确传播,防止在获取 JSSDK 配置时发生协程泄露。
  2. package js
  import (
  	"bytes"
  	context2 "context"
  	"errors"
  	"fmt"
  	"io"
  	"net/http"
  	"reflect"
  	"testing"

  	"github.com/silenceper/wechat/v2/cache"
  	"github.com/silenceper/wechat/v2/credential"
  	"github.com/silenceper/wechat/v2/officialaccount/config"
  	"github.com/silenceper/wechat/v2/officialaccount/context"
  	"github.com/silenceper/wechat/v2/util"
  )
  // mockAccessTokenHandle is a stand-in access-token source for tests. Both
  // lookups return the same fixed token and a nil error; the context argument
  // of the context-aware variant is ignored.
  type mockAccessTokenHandle struct{}

  // GetAccessToken returns the fixed mock token by delegating to the
  // context-aware variant with a background context.
  func (h *mockAccessTokenHandle) GetAccessToken() (string, error) {
  	return h.GetAccessTokenContext(context2.Background())
  }

  // GetAccessTokenContext returns the fixed mock token; the context is unused.
  func (h *mockAccessTokenHandle) GetAccessTokenContext(_ context2.Context) (string, error) {
  	const token = "mock-access-token"
  	return token, nil
  }
  // contextCheckingRoundTripper is a fake http.RoundTripper that never touches
  // the network. For every request it:
  //   - logs whether the request's context is the same value as originalCtx
  //     (informational only);
  //   - when key is non-nil, fails the test unless the request context carries
  //     a value equal to expectedVal under key;
  //   - returns the context's error if the request context is already done;
  //   - otherwise returns a canned 200 response whose JSON body holds a jsapi
  //     ticket.
  type contextCheckingRoundTripper struct {
  	originalCtx context2.Context // context the test passed in; only used for logging
  	t           *testing.T       // receives log output and assertion failures
  	key         interface{}      // context key to look up; nil disables the value check
  	expectedVal interface{}      // value expected under key
  }

  // RoundTrip implements http.RoundTripper.
  func (rt *contextCheckingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  	reqCtx := req.Context()

  	// Log whether the request still carries the caller's exact context value.
  	// This is diagnostic only; the key/value check below is the real assertion.
  	rt.t.Logf("比较上下文的内存地址:\n")
  	if reqCtx == rt.originalCtx {
  		rt.t.Logf("上下文具有相同的内存地址。原始上下文: %p, 请求上下文: %p\n", rt.originalCtx, reqCtx)
  	} else {
  		rt.t.Logf("上下文具有不同的内存地址。原始上下文: %p, 请求上下文: %p\n", rt.originalCtx, reqCtx)
  	}

  	if rt.key != nil {
  		value := reqCtx.Value(rt.key)
  		rt.t.Logf("检查请求上下文中的键 %v:\n", rt.key)
  		// reflect.DeepEqual rather than != : comparing interface values whose
  		// dynamic type is uncomparable (slice, map, func) with != panics.
  		if !reflect.DeepEqual(value, rt.expectedVal) {
  			rt.t.Errorf("上下文键 %v 的值不匹配: 预期 %v, 实际 %v\n", rt.key, rt.expectedVal, value)
  		} else {
  			rt.t.Logf("上下文键 %v 的值匹配: 预期 %v, 实际 %v\n", rt.key, rt.expectedVal, value)
  		}
  	}

  	// Surface cancellation/expiry as an error instead of a response.
  	select {
  	case <-reqCtx.Done():
  		return nil, reqCtx.Err()
  	default:
  	}

  	// Canned successful response carrying a valid jsapi-ticket JSON payload.
  	responseBody := `{"ticket":"mock-ticket","expires_in":7200}`
  	response := &http.Response{
  		Status:        "200 OK",
  		StatusCode:    http.StatusOK,
  		Proto:         "HTTP/1.1",
  		ProtoMajor:    1,
  		ProtoMinor:    1,
  		Body:          io.NopCloser(bytes.NewReader([]byte(responseBody))),
  		ContentLength: int64(len(responseBody)),
  		Header:        make(http.Header),
  		Request:       req, // link back to the originating request, as real transports do
  	}
  	response.Header.Set("Content-Type", "application/json")
  	return response, nil
  }
  // contextKey is a private named type for context.WithValue keys. Being a
  // distinct type rather than the built-in string, keys created by this test
  // cannot collide with plain string keys set elsewhere.
  type (
  	contextKey string
  )
  75. // setupJsInstance 初始化 Js 实例和 HTTP 客户端
  76. func setupJsInstance(t *testing.T, ctx context2.Context, key, val interface{}) (*Js, func()) {
  77. cfg := &config.Config{
  78. AppID: "test-app-id",
  79. AppSecret: "test-app-secret",
  80. Cache: cache.NewMemory(),
  81. }
  82. cacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", credential.CacheKeyOfficialAccountPrefix, cfg.AppID)
  83. if err := cfg.Cache.Delete(cacheKey); err != nil {
  84. t.Fatalf("清除缓存失败: %v", err)
  85. }
  86. t.Log("清除 jsapi_ticket 的缓存:", cacheKey)
  87. ctxHandle := &context.Context{Config: cfg, AccessTokenHandle: &mockAccessTokenHandle{}}
  88. jsInstance := NewJs(ctxHandle, cfg.AppID)
  89. jsInstance.SetJsTicketHandle(credential.NewDefaultJsTicket(cfg.AppID, credential.CacheKeyOfficialAccountPrefix, cfg.Cache))
  90. originalClient := util.DefaultHTTPClient
  91. util.DefaultHTTPClient = &http.Client{
  92. Transport: &contextCheckingRoundTripper{originalCtx: ctx, t: t, key: key, expectedVal: val},
  93. }
  94. return jsInstance, func() { util.DefaultHTTPClient = originalClient }
  95. }
  96. // TestGetConfigContext 测试GetConfigContext的上下文传递和取消行为。
  97. func TestGetConfigContext(t *testing.T) {
  98. t.Run("ContextPassing", func(t *testing.T) {
  99. ctxKey := contextKey("testKey111") // 使用自定义类型 contextKey
  100. ctxValue := "testValue222"
  101. ctx := context2.WithValue(context2.Background(), ctxKey, ctxValue)
  102. t.Logf("创建的测试上下文: %p, 添加的键值对: %v=%v\n", ctx, ctxKey, ctxValue)
  103. jsInstance, cleanup := setupJsInstance(t, ctx, ctxKey, ctxValue)
  104. defer cleanup()
  105. t.Log("调用 GetConfigContext")
  106. config2, err := jsInstance.GetConfigContext(ctx, "https://www.baidu.com", "test-app-id")
  107. if err != nil {
  108. t.Fatalf("GetConfigContext 失败: %v", err)
  109. }
  110. if config2.AppID != "test-app-id" {
  111. t.Errorf("预期 AppID 为 %s,实际为 %s", "test-app-id", config2.AppID)
  112. }
  113. })
  114. t.Run("ContextCancellation", func(t *testing.T) {
  115. ctx, cancel := context2.WithCancel(context2.Background())
  116. defer cancel()
  117. jsInstance, cleanup := setupJsInstance(t, ctx, nil, nil)
  118. defer cleanup()
  119. cancel()
  120. t.Log("调用 GetConfigContext(已取消上下文)")
  121. _, err := jsInstance.GetConfigContext(ctx, "https://www.baidu.com", "test-app-id")
  122. if err == nil {
  123. t.Error("预期上下文取消错误,但 GetConfigContext 未返回错误")
  124. } else if !errors.Is(err, context2.Canceled) {
  125. t.Errorf("预期错误为 context.Canceled,实际为: %v", err)
  126. }
  127. })
  128. }