Forráskód Böngészése

GetAccessToken支持Context (#618)

okhowang 3 éve
szülő
commit
5380d5bee7

+ 8 - 0
credential/access_token.go

@@ -1,6 +1,14 @@
 package credential
 
+import "context"
+
 // AccessTokenHandle AccessToken 接口
 type AccessTokenHandle interface {
 	GetAccessToken() (accessToken string, err error)
 }
+
+// AccessTokenContextHandle AccessToken 接口
+type AccessTokenContextHandle interface {
+	AccessTokenHandle
+	GetAccessTokenContext(ctx context.Context) (accessToken string, err error)
+}

+ 21 - 5
credential/default_access_token.go

@@ -1,6 +1,7 @@
 package credential
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"sync"
@@ -33,7 +34,7 @@ type DefaultAccessToken struct {
 }
 
 // NewDefaultAccessToken new DefaultAccessToken
-func NewDefaultAccessToken(appID, appSecret, cacheKeyPrefix string, cache cache.Cache) AccessTokenHandle {
+func NewDefaultAccessToken(appID, appSecret, cacheKeyPrefix string, cache cache.Cache) AccessTokenContextHandle {
 	if cache == nil {
 		panic("cache is ineed")
 	}
@@ -56,6 +57,11 @@ type ResAccessToken struct {
 
 // GetAccessToken 获取access_token,先从cache中获取,没有则从服务端获取
 func (ak *DefaultAccessToken) GetAccessToken() (accessToken string, err error) {
+	return ak.GetAccessTokenContext(context.Background())
+}
+
+// GetAccessTokenContext 获取access_token,先从cache中获取,没有则从服务端获取
+func (ak *DefaultAccessToken) GetAccessTokenContext(ctx context.Context) (accessToken string, err error) {
 	// 先从cache中取
 	accessTokenCacheKey := fmt.Sprintf("%s_access_token_%s", ak.cacheKeyPrefix, ak.appID)
 	if val := ak.cache.Get(accessTokenCacheKey); val != nil {
@@ -73,7 +79,7 @@ func (ak *DefaultAccessToken) GetAccessToken() (accessToken string, err error) {
 
 	// cache失效,从微信服务器获取
 	var resAccessToken ResAccessToken
-	resAccessToken, err = GetTokenFromServer(fmt.Sprintf(accessTokenURL, ak.appID, ak.appSecret))
+	resAccessToken, err = GetTokenFromServerContext(ctx, fmt.Sprintf(accessTokenURL, ak.appID, ak.appSecret))
 	if err != nil {
 		return
 	}
@@ -97,7 +103,7 @@ type WorkAccessToken struct {
 }
 
 // NewWorkAccessToken new WorkAccessToken
-func NewWorkAccessToken(corpID, corpSecret, cacheKeyPrefix string, cache cache.Cache) AccessTokenHandle {
+func NewWorkAccessToken(corpID, corpSecret, cacheKeyPrefix string, cache cache.Cache) AccessTokenContextHandle {
 	if cache == nil {
 		panic("cache the not exist")
 	}
@@ -112,6 +118,11 @@ func NewWorkAccessToken(corpID, corpSecret, cacheKeyPrefix string, cache cache.C
 
 // GetAccessToken 企业微信获取access_token,先从cache中获取,没有则从服务端获取
 func (ak *WorkAccessToken) GetAccessToken() (accessToken string, err error) {
+	return ak.GetAccessTokenContext(context.Background())
+}
+
+// GetAccessTokenContext 企业微信获取access_token,先从cache中获取,没有则从服务端获取
+func (ak *WorkAccessToken) GetAccessTokenContext(ctx context.Context) (accessToken string, err error) {
 	// 加上lock,是为了防止在并发获取token时,cache刚好失效,导致从微信服务器上获取到不同token
 	ak.accessTokenLock.Lock()
 	defer ak.accessTokenLock.Unlock()
@@ -124,7 +135,7 @@ func (ak *WorkAccessToken) GetAccessToken() (accessToken string, err error) {
 
 	// cache失效,从微信服务器获取
 	var resAccessToken ResAccessToken
-	resAccessToken, err = GetTokenFromServer(fmt.Sprintf(workAccessTokenURL, ak.CorpID, ak.CorpSecret))
+	resAccessToken, err = GetTokenFromServerContext(ctx, fmt.Sprintf(workAccessTokenURL, ak.CorpID, ak.CorpSecret))
 	if err != nil {
 		return
 	}
@@ -140,8 +151,13 @@ func (ak *WorkAccessToken) GetAccessToken() (accessToken string, err error) {
 
 // GetTokenFromServer 强制从微信服务器获取token
 func GetTokenFromServer(url string) (resAccessToken ResAccessToken, err error) {
+	return GetTokenFromServerContext(context.Background(), url)
+}
+
+// GetTokenFromServerContext 强制从微信服务器获取token
+func GetTokenFromServerContext(ctx context.Context, url string) (resAccessToken ResAccessToken, err error) {
 	var body []byte
-	body, err = util.HTTPGet(url)
+	body, err = util.HTTPGetContext(ctx, url)
 	if err != nil {
 		return
 	}

+ 0 - 5
go.sum

@@ -26,12 +26,10 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W
 github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0=
 github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI=
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
-github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw=
 github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
 github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
 github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
 github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
 github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
@@ -48,7 +46,6 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108
 github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0=
 github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
 github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
-github.com/onsi/ginkgo/v2 v2.0.0 h1:CcuG/HvWNkkaqCUpJifQY8z7qEMBJya6aLPx6ftGyjQ=
 github.com/onsi/ginkgo/v2 v2.0.0/go.mod h1:vw5CSIxN1JObi/U8gcbwft7ZxR2dgaR70JSE3/PpL4c=
 github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
 github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
@@ -118,7 +115,6 @@ golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4f
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
 google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
@@ -127,7 +123,6 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
 google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
-google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk=
 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

+ 9 - 0
officialaccount/officialaccount.go

@@ -1,6 +1,7 @@
 package officialaccount
 
 import (
+	stdcontext "context"
 	"net/http"
 
 	"github.com/silenceper/wechat/v2/officialaccount/draft"
@@ -94,6 +95,14 @@ func (officialAccount *OfficialAccount) GetAccessToken() (string, error) {
 	return officialAccount.ctx.GetAccessToken()
 }
 
+// GetAccessTokenContext 获取access_token
+func (officialAccount *OfficialAccount) GetAccessTokenContext(ctx stdcontext.Context) (string, error) {
+	if c, ok := officialAccount.ctx.AccessTokenHandle.(credential.AccessTokenContextHandle); ok {
+		return c.GetAccessTokenContext(ctx)
+	}
+	return officialAccount.ctx.GetAccessToken()
+}
+
 // GetOauth oauth2网页授权
 func (officialAccount *OfficialAccount) GetOauth() *oauth.Oauth {
 	if officialAccount.oauth == nil {