ソースを参照

将获取AccessToken和jsTicket单独抽象为interface

silenceper 6 年 前
コミット
b599e93c5b

+ 6 - 0
credential/access_token.go

@@ -0,0 +1,6 @@
+package credential
+
+//AccessTokenHandle AccessToken 接口
+type AccessTokenHandle interface {
+	GetAccessToken() (accessToken string, err error)
+}

+ 97 - 0
credential/default_access_token.go

@@ -0,0 +1,97 @@
+package credential
+
+import (
+	"encoding/json"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/silenceper/wechat/v2/cache"
+	"github.com/silenceper/wechat/v2/util"
+)
+
+const (
+	//AccessTokenURL 获取access_token的接口
+	accessTokenURL = "https://api.weixin.qq.com/cgi-bin/token"
+	//CacheKeyOfficialAccountPrefix 微信公众号cache key前缀
+	CacheKeyOfficialAccountPrefix = "gowechat_officialaccount_"
+)
+
+//DefaultAccessToken 默认AccessToken 获取
+type DefaultAccessToken struct {
+	appID           string
+	appSecret       string
+	cacheKeyPrefix  string
+	cache           cache.Cache
+	accessTokenLock *sync.Mutex
+}
+
+//NewDefaultAccessToken new DefaultAccessToken
+func NewDefaultAccessToken(appID, appSecret, cacheKeyPrefix string, cache cache.Cache) AccessTokenHandle {
+	if cache == nil {
+		panic("cache is ineed")
+	}
+	return &DefaultAccessToken{
+		appID:           appID,
+		appSecret:       appSecret,
+		cache:           cache,
+		cacheKeyPrefix:  cacheKeyPrefix,
+		accessTokenLock: new(sync.Mutex),
+	}
+}
+
+//ResAccessToken struct
+type ResAccessToken struct {
+	util.CommonError
+
+	AccessToken string `json:"access_token"`
+	ExpiresIn   int64  `json:"expires_in"`
+}
+
+//GetAccessToken 获取access_token,先从cache中获取,没有则从服务端获取
+func (ak *DefaultAccessToken) GetAccessToken() (accessToken string, err error) {
+	//加上lock,是为了防止在并发获取token时,cache刚好失效,导致从微信服务器上获取到不同token
+	ak.accessTokenLock.Lock()
+	defer ak.accessTokenLock.Unlock()
+
+	accessTokenCacheKey := fmt.Sprintf("%s_access_token_%s", ak.cacheKeyPrefix, ak.appID)
+	val := ak.cache.Get(accessTokenCacheKey)
+	if val != nil {
+		accessToken = val.(string)
+		return
+	}
+
+	//cache失效,从微信服务器获取
+	var resAccessToken ResAccessToken
+	resAccessToken, err = GetTokenFromServer(ak.appID, ak.appSecret)
+	if err != nil {
+		return
+	}
+
+	expires := resAccessToken.ExpiresIn - 1500
+	err = ak.cache.Set(accessTokenCacheKey, resAccessToken.AccessToken, time.Duration(expires)*time.Second)
+	if err != nil {
+		return
+	}
+	accessToken = resAccessToken.AccessToken
+	return
+}
+
+//GetTokenFromServer 强制从微信服务器获取token
+func GetTokenFromServer(appID, appSecret string) (resAccessToken ResAccessToken, err error) {
+	url := fmt.Sprintf("%s?grant_type=client_credential&appid=%s&secret=%s", accessTokenURL, appID, appSecret)
+	var body []byte
+	body, err = util.HTTPGet(url)
+	if err != nil {
+		return
+	}
+	err = json.Unmarshal(body, &resAccessToken)
+	if err != nil {
+		return
+	}
+	if resAccessToken.ErrMsg != "" {
+		err = fmt.Errorf("get access_token error : errcode=%v , errormsg=%v", resAccessToken.ErrCode, resAccessToken.ErrMsg)
+		return
+	}
+	return
+}

+ 80 - 0
credential/default_js_ticket.go

@@ -0,0 +1,80 @@
+package credential
+
+import (
+	"encoding/json"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/silenceper/wechat/v2/cache"
+	"github.com/silenceper/wechat/v2/util"
+)
+
+//获取ticket的url
+const getTicketURL = "https://api.weixin.qq.com/cgi-bin/ticket/getticket?access_token=%s&type=jsapi"
+
+//DefaultJsTicket 默认获取js ticket方法
+type DefaultJsTicket struct {
+	appID          string
+	cacheKeyPrefix string
+	cache          cache.Cache
+	//jsAPITicket 读写锁 同一个AppID一个
+	jsAPITicketLock *sync.Mutex
+}
+
+//NewDefaultJsTicket new
+func NewDefaultJsTicket(appID string, cacheKeyPrefix string, cache cache.Cache) JsTicketHandle {
+	return &DefaultJsTicket{
+		appID:           appID,
+		cache:           cache,
+		cacheKeyPrefix:  cacheKeyPrefix,
+		jsAPITicketLock: new(sync.Mutex),
+	}
+}
+
+// ResTicket 请求jsapi_tikcet返回结果
+type ResTicket struct {
+	util.CommonError
+
+	Ticket    string `json:"ticket"`
+	ExpiresIn int64  `json:"expires_in"`
+}
+
+//GetTicket 获取jsapi_ticket
+func (js *DefaultJsTicket) GetTicket(accessToken string) (ticketStr string, err error) {
+	js.jsAPITicketLock.Lock()
+	defer js.jsAPITicketLock.Unlock()
+
+	//先从cache中取
+	jsAPITicketCacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", js.cacheKeyPrefix, js.appID)
+	val := js.cache.Get(jsAPITicketCacheKey)
+	if val != nil {
+		ticketStr = val.(string)
+		return
+	}
+	var ticket ResTicket
+	ticket, err = GetTicketFromServer(accessToken)
+	if err != nil {
+		return
+	}
+	expires := ticket.ExpiresIn - 1500
+	err = js.cache.Set(jsAPITicketCacheKey, ticket.Ticket, time.Duration(expires)*time.Second)
+	ticketStr = ticket.Ticket
+	return
+}
+
+//GetTicketFromServer 从服务器中获取ticket
+func GetTicketFromServer(accessToken string) (ticket ResTicket, err error) {
+	var response []byte
+	url := fmt.Sprintf(getTicketURL, accessToken)
+	response, err = util.HTTPGet(url)
+	err = json.Unmarshal(response, &ticket)
+	if err != nil {
+		return
+	}
+	if ticket.ErrCode != 0 {
+		err = fmt.Errorf("getTicket Error : errcode=%d , errmsg=%s", ticket.ErrCode, ticket.ErrMsg)
+		return
+	}
+	return
+}

+ 7 - 0
credential/js_ticket.go

@@ -0,0 +1,7 @@
+package credential
+
+//JsTicketHandle js ticket获取
+type JsTicketHandle interface {
+	//GetTicket 获取ticket
+	GetTicket(accessToken string) (ticket string, err error)
+}

+ 0 - 87
officialaccount/context/access_token.go

@@ -1,87 +0,0 @@
-package context
-
-import (
-	"encoding/json"
-	"fmt"
-	"sync"
-	"time"
-
-	"github.com/silenceper/wechat/v2/util"
-)
-
-const (
-	//AccessTokenURL 获取access_token的接口
-	AccessTokenURL = "https://api.weixin.qq.com/cgi-bin/token"
-	//CacheKeyPrefix 微信公众号cache key前缀
-	CacheKeyPrefix = "gowechat_officialaccount_"
-)
-
-//ResAccessToken struct
-type ResAccessToken struct {
-	util.CommonError
-
-	AccessToken string `json:"access_token"`
-	ExpiresIn   int64  `json:"expires_in"`
-}
-
-//GetAccessTokenFunc 获取 access token 的函数签名
-type GetAccessTokenFunc func(ctx *Context) (accessToken string, err error)
-
-//SetAccessTokenLock 设置读写锁(一个appID一个读写锁)
-func (ctx *Context) SetAccessTokenLock(l *sync.RWMutex) {
-	ctx.accessTokenLock = l
-}
-
-//SetGetAccessTokenFunc 设置自定义获取accessToken的方式, 需要自己实现缓存
-func (ctx *Context) SetGetAccessTokenFunc(f GetAccessTokenFunc) {
-	ctx.accessTokenFunc = f
-}
-
-//GetAccessToken 获取access_token
-func (ctx *Context) GetAccessToken() (accessToken string, err error) {
-	ctx.accessTokenLock.Lock()
-	defer ctx.accessTokenLock.Unlock()
-
-	if ctx.accessTokenFunc != nil {
-		return ctx.accessTokenFunc(ctx)
-	}
-	accessTokenCacheKey := fmt.Sprintf("%s_access_token_%s", CacheKeyPrefix, ctx.AppID)
-	val := ctx.Cache.Get(accessTokenCacheKey)
-	if val != nil {
-		accessToken = val.(string)
-		return
-	}
-
-	//从微信服务器获取
-	var resAccessToken ResAccessToken
-	resAccessToken, err = ctx.GetAccessTokenFromServer()
-	if err != nil {
-		return
-	}
-
-	accessToken = resAccessToken.AccessToken
-	return
-}
-
-//GetAccessTokenFromServer 强制从微信服务器获取token
-func (ctx *Context) GetAccessTokenFromServer() (resAccessToken ResAccessToken, err error) {
-	url := fmt.Sprintf("%s?grant_type=client_credential&appid=%s&secret=%s", AccessTokenURL, ctx.AppID, ctx.AppSecret)
-	var body []byte
-	body, err = util.HTTPGet(url)
-	if err != nil {
-		return
-	}
-	err = json.Unmarshal(body, &resAccessToken)
-	if err != nil {
-		return
-	}
-	if resAccessToken.ErrMsg != "" {
-		err = fmt.Errorf("get access_token error : errcode=%v , errormsg=%v", resAccessToken.ErrCode, resAccessToken.ErrMsg)
-		return
-	}
-
-	accessTokenCacheKey := fmt.Sprintf("%s_access_token_%s", CacheKeyPrefix, ctx.AppID)
-	expires := resAccessToken.ExpiresIn - 1500
-	err = ctx.Cache.Set(accessTokenCacheKey, resAccessToken.AccessToken, time.Duration(expires)*time.Second)
-	return
-}

+ 2 - 21
officialaccount/context/context.go

@@ -1,31 +1,12 @@
 package context
 
 import (
-	"sync"
-
+	"github.com/silenceper/wechat/v2/credential"
 	"github.com/silenceper/wechat/v2/officialaccount/config"
 )
 
 // Context struct
 type Context struct {
 	*config.Config
-
-	//accessTokenLock 读写锁 同一个AppID一个
-	accessTokenLock *sync.RWMutex
-
-	//jsAPITicket 读写锁 同一个AppID一个
-	jsAPITicketLock *sync.RWMutex
-
-	//accessTokenFunc 自定义获取 access token 的方法
-	accessTokenFunc GetAccessTokenFunc
-}
-
-// SetJsAPITicketLock 设置jsAPITicket的lock
-func (ctx *Context) SetJsAPITicketLock(lock *sync.RWMutex) {
-	ctx.jsAPITicketLock = lock
-}
-
-// GetJsAPITicketLock 获取jsAPITicket 的lock
-func (ctx *Context) GetJsAPITicketLock() *sync.RWMutex {
-	return ctx.jsAPITicketLock
+	credential.AccessTokenHandle
 }

+ 15 - 60
officialaccount/js/js.go

@@ -1,19 +1,17 @@
 package js
 
 import (
-	"encoding/json"
 	"fmt"
-	"time"
 
+	"github.com/silenceper/wechat/v2/credential"
 	"github.com/silenceper/wechat/v2/officialaccount/context"
 	"github.com/silenceper/wechat/v2/util"
 )
 
-const getTicketURL = "https://api.weixin.qq.com/cgi-bin/ticket/getticket?access_token=%s&type=jsapi"
-
 // Js struct
 type Js struct {
 	*context.Context
+	credential.JsTicketHandle
 }
 
 // Config 返回给用户jssdk配置信息
@@ -24,27 +22,31 @@ type Config struct {
 	Signature string `json:"signature"`
 }
 
-// resTicket 请求jsapi_tikcet返回结果
-type resTicket struct {
-	util.CommonError
-
-	Ticket    string `json:"ticket"`
-	ExpiresIn int64  `json:"expires_in"`
-}
-
 //NewJs init
 func NewJs(context *context.Context) *Js {
 	js := new(Js)
 	js.Context = context
+	jsTicketHandle := credential.NewDefaultJsTicket(context.AppID, credential.CacheKeyOfficialAccountPrefix, context.Cache)
+	js.SetJsTicketHandle(jsTicketHandle)
 	return js
 }
 
+//SetJsTicketHandle 自定义js ticket取值方式
+func (js *Js) SetJsTicketHandle(ticketHandle credential.JsTicketHandle) {
+	js.JsTicketHandle = ticketHandle
+}
+
 //GetConfig 获取jssdk需要的配置参数
 //uri 为当前网页地址
 func (js *Js) GetConfig(uri string) (config *Config, err error) {
 	config = new(Config)
+	var accessToken string
+	accessToken, err = js.GetAccessToken()
+	if err != nil {
+		return
+	}
 	var ticketStr string
-	ticketStr, err = js.GetTicket()
+	ticketStr, err = js.GetTicket(accessToken)
 	if err != nil {
 		return
 	}
@@ -60,50 +62,3 @@ func (js *Js) GetConfig(uri string) (config *Config, err error) {
 	config.Signature = sigStr
 	return
 }
-
-//GetTicket 获取jsapi_ticket
-func (js *Js) GetTicket() (ticketStr string, err error) {
-	js.GetJsAPITicketLock().Lock()
-	defer js.GetJsAPITicketLock().Unlock()
-
-	//先从cache中取
-	jsAPITicketCacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", context.CacheKeyPrefix, js.AppID)
-	val := js.Cache.Get(jsAPITicketCacheKey)
-	if val != nil {
-		ticketStr = val.(string)
-		return
-	}
-	var ticket resTicket
-	ticket, err = js.getTicketFromServer()
-	if err != nil {
-		return
-	}
-	ticketStr = ticket.Ticket
-	return
-}
-
-//getTicketFromServer 强制从服务器中获取ticket
-func (js *Js) getTicketFromServer() (ticket resTicket, err error) {
-	var accessToken string
-	accessToken, err = js.GetAccessToken()
-	if err != nil {
-		return
-	}
-
-	var response []byte
-	url := fmt.Sprintf(getTicketURL, accessToken)
-	response, err = util.HTTPGet(url)
-	err = json.Unmarshal(response, &ticket)
-	if err != nil {
-		return
-	}
-	if ticket.ErrCode != 0 {
-		err = fmt.Errorf("getTicket Error : errcode=%d , errmsg=%s", ticket.ErrCode, ticket.ErrMsg)
-		return
-	}
-
-	jsAPITicketCacheKey := fmt.Sprintf("%s_jsapi_ticket_%s", context.CacheKeyPrefix, js.AppID)
-	expires := ticket.ExpiresIn - 1500
-	err = js.Cache.Set(jsAPITicketCacheKey, ticket.Ticket, time.Duration(expires)*time.Second)
-	return
-}

+ 10 - 8
officialaccount/officialaccount.go

@@ -2,8 +2,8 @@ package officialaccount
 
 import (
 	"net/http"
-	"sync"
 
+	"github.com/silenceper/wechat/v2/credential"
 	"github.com/silenceper/wechat/v2/officialaccount/basic"
 	"github.com/silenceper/wechat/v2/officialaccount/config"
 	"github.com/silenceper/wechat/v2/officialaccount/context"
@@ -24,15 +24,17 @@ type OfficialAccount struct {
 
 //NewOfficialAccount 实例化公众号API
 func NewOfficialAccount(cfg *config.Config) *OfficialAccount {
-	//if cfg.Cache == nil {
-	//	panic("cache未设置")
-	//}
+	defaultAK := credential.NewDefaultAccessToken(cfg.AppID, cfg.AppSecret, credential.CacheKeyOfficialAccountPrefix, cfg.Cache)
 	ctx := &context.Context{
-		Config: cfg,
+		Config:            cfg,
+		AccessTokenHandle: defaultAK,
 	}
-	ctx.SetAccessTokenLock(new(sync.RWMutex))
-	ctx.SetJsAPITicketLock(new(sync.RWMutex))
-	return &OfficialAccount{ctx}
+	return &OfficialAccount{ctx: ctx}
+}
+
+//SetAccessTokenHandle 自定义access_token获取方式
+func (officialAccount *OfficialAccount) SetAccessTokenHandle(accessTokenHandle credential.AccessTokenHandle) {
+	officialAccount.ctx.AccessTokenHandle = accessTokenHandle
 }
 
 // GetContext get Context

+ 21 - 5
openplatform/officialaccount/officialaccount.go

@@ -1,9 +1,9 @@
 package officialaccount
 
 import (
+	"github.com/silenceper/wechat/v2/credential"
 	"github.com/silenceper/wechat/v2/officialaccount"
 	offConfig "github.com/silenceper/wechat/v2/officialaccount/config"
-	offContext "github.com/silenceper/wechat/v2/officialaccount/context"
 	opContext "github.com/silenceper/wechat/v2/openplatform/context"
 )
 
@@ -25,9 +25,25 @@ func NewOfficialAccount(opCtx *opContext.Context, appID string) *OfficialAccount
 		Cache:          opCtx.Cache,
 	})
 	//设置获取access_token的函数
-	officialAccount.GetContext().SetGetAccessTokenFunc(func(offCtx *offContext.Context) (accessToken string, err error) {
-		// 获取授权方的access_token
-		return opCtx.GetAuthrAccessToken(appID)
-	})
+	officialAccount.SetAccessTokenHandle(NewDefaultAuthrAccessToken(opCtx, appID))
 	return &OfficialAccount{appID: appID, OfficialAccount: officialAccount}
 }
+
+//DefaultAuthrAccessToken 默认获取授权ak的方法
+type DefaultAuthrAccessToken struct {
+	opCtx *opContext.Context
+	appID string
+}
+
+//NewDefaultAuthrAccessToken New
+func NewDefaultAuthrAccessToken(opCtx *opContext.Context, appID string) credential.AccessTokenHandle {
+	return &DefaultAuthrAccessToken{
+		opCtx: opCtx,
+		appID: appID,
+	}
+}
+
+//GetAccessToken 获取ak
+func (ak *DefaultAuthrAccessToken) GetAccessToken() (string, error) {
+	return ak.opCtx.GetAuthrAccessToken(ak.appID)
+}