Bläddra i källkod

fix some bugs (#277)

* fix payRequset xml marshal: root element should be xml

* support HMAC-SHA256 for unifiedorder api

* support HMAC-SHA256 for PaidVerifySign

* fix SignType is nil

* fix code style

* constantize SignType

* add comments

* fix code style
huang wei 6 år sedan
förälder
incheckning
c14c020a3c
6 ändrade filer med 95 tillägg och 93 borttagningar
  1. 2 0
      go.sum
  2. 9 2
      pay/notify/paid.go
  3. 14 18
      pay/order/pay.go
  4. 9 8
      pay/refund/refund.go
  5. 42 12
      util/crypto.go
  6. 19 53
      util/param.go

+ 2 - 0
go.sum

@@ -14,6 +14,7 @@ github.com/konsorten/go-windows-terminal-sequences v1.0.3 h1:CE8S1cTafDpPvMhIxNJ
 github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4=
 github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/sirupsen/logrus v1.6.0 h1:UBcNElsrwanuuMsnGSlYmtmgbb23qDR5dG+6X6Oo89I=
 github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
@@ -21,6 +22,7 @@ github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng=
 github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
+github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
 github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.0.0-20200510223506-06a226fb4e37 h1:cg5LA/zNPRzIXIWSCxQW10Rvpy94aQh3LT/ShoCpkHw=

+ 9 - 2
pay/notify/paid.go

@@ -85,8 +85,15 @@ func (notify *Notify) PaidVerifySign(notifyRes PaidResult) bool {
 	// STEP3, 在键值对的最后加上key=API_KEY
 	signStrings = signStrings + "key=" + notify.Key
 
-	// STEP4, 进行MD5签名并且将所有字符转为大写.
-	sign := util.MD5Sum(signStrings)
+	// STEP4, 根据SignType计算出签名
+	var signType string
+	if notifyRes.SignType != nil {
+		signType = *notifyRes.SignType
+	}
+	sign, err := util.CalculateSign(signStrings, signType, notify.Key)
+	if err != nil {
+		return false
+	}
 	if sign != *notifyRes.Sign {
 		return false
 	}

+ 14 - 18
pay/order/pay.go

@@ -1,13 +1,8 @@
 package order
 
 import (
-	"crypto/hmac"
-	"crypto/md5"
-	"crypto/sha256"
-	"encoding/hex"
 	"encoding/xml"
 	"errors"
-	"hash"
 	"strconv"
 	"strings"
 	"time"
@@ -96,13 +91,14 @@ type payRequest struct {
 	LimitPay       string `xml:"limit_pay,omitempty"`   //
 	OpenID         string `xml:"openid,omitempty"`      // 用户标识
 	SceneInfo      string `xml:"scene_info,omitempty"`  // 场景信息
+
+	XMLName struct{} `xml:"xml"`
 }
 
 // BridgeConfig get js bridge config
 func (o *Order) BridgeConfig(p *Params) (cfg Config, err error) {
 	var (
 		buffer    strings.Builder
-		h         hash.Hash
 		timestamp = strconv.FormatInt(time.Now().Unix(), 10)
 	)
 	order, err := o.PrePayOrder(p)
@@ -121,14 +117,13 @@ func (o *Order) BridgeConfig(p *Params) (cfg Config, err error) {
 	buffer.WriteString(timestamp)
 	buffer.WriteString("&key=")
 	buffer.WriteString(o.Key)
-	if p.SignType == "MD5" {
-		h = md5.New()
-	} else {
-		h = hmac.New(sha256.New, []byte(o.Key))
+
+	sign, err := util.CalculateSign(buffer.String(), p.SignType, o.Key)
+	if err != nil {
+		return
 	}
-	h.Write([]byte(buffer.String()))
 	// 签名
-	cfg.PaySign = strings.ToUpper(hex.EncodeToString(h.Sum(nil)))
+	cfg.PaySign = sign
 	cfg.NonceStr = order.NonceStr
 	cfg.Timestamp = timestamp
 	cfg.PrePayID = order.PrePayID
@@ -143,13 +138,13 @@ func (o *Order) PrePayOrder(p *Params) (payOrder PreOrder, err error) {
 	notifyURL := o.NotifyURL
 	// 签名类型
 	if p.SignType == "" {
-		p.SignType = "MD5"
+		p.SignType = util.SignTypeMD5
 	}
 	// 通知地址
 	if p.NotifyURL != "" {
 		notifyURL = p.NotifyURL
 	}
-	param := make(map[string]interface{})
+	param := make(map[string]string)
 	param["appid"] = o.AppID
 	param["body"] = p.Body
 	param["mch_id"] = o.MchID
@@ -165,9 +160,10 @@ func (o *Order) PrePayOrder(p *Params) (payOrder PreOrder, err error) {
 	param["goods_tag"] = p.GoodsTag
 	param["notify_url"] = notifyURL
 
-	bizKey := "&key=" + o.Key
-	str := util.OrderParam(param, bizKey)
-	sign := util.MD5Sum(str)
+	sign, err := util.ParamSign(param, o.Key)
+	if err != nil {
+		return
+	}
 	request := payRequest{
 		AppID:          o.AppID,
 		MchID:          o.MchID,
@@ -202,7 +198,7 @@ func (o *Order) PrePayOrder(p *Params) (payOrder PreOrder, err error) {
 		err = errors.New(payOrder.ErrCode + payOrder.ErrCodeDes)
 		return
 	}
-	err = errors.New("[msg : xmlUnmarshalError] [rawReturn : " + string(rawRet) + "] [params : " + str + "] [sign : " + sign + "]")
+	err = errors.New("[msg : xmlUnmarshalError] [rawReturn : " + string(rawRet) + "] [sign : " + sign + "]")
 	return
 }
 

+ 9 - 8
pay/refund/refund.go

@@ -73,7 +73,7 @@ type Response struct {
 //Refund 退款申请
 func (refund *Refund) Refund(p *Params) (rsp Response, err error) {
 	nonceStr := util.RandomStr(32)
-	param := make(map[string]interface{})
+	param := make(map[string]string)
 	param["appid"] = refund.AppID
 	param["mch_id"] = refund.MchID
 	param["nonce_str"] = nonceStr
@@ -81,18 +81,20 @@ func (refund *Refund) Refund(p *Params) (rsp Response, err error) {
 	param["refund_desc"] = p.RefundDesc
 	param["refund_fee"] = p.RefundFee
 	param["total_fee"] = p.TotalFee
-	param["sign_type"] = "MD5"
+	param["sign_type"] = util.SignTypeMD5
 	param["transaction_id"] = p.TransactionID
 
-	bizKey := "&key=" + refund.Key
-	str := util.OrderParam(param, bizKey)
-	sign := util.MD5Sum(str)
+	sign, err := util.ParamSign(param, refund.Key)
+	if err != nil {
+		return
+	}
+
 	request := request{
 		AppID:         refund.AppID,
 		MchID:         refund.MchID,
 		NonceStr:      nonceStr,
 		Sign:          sign,
-		SignType:      "MD5",
+		SignType:      util.SignTypeMD5,
 		TransactionID: p.TransactionID,
 		OutRefundNo:   p.OutRefundNo,
 		TotalFee:      p.TotalFee,
@@ -115,7 +117,6 @@ func (refund *Refund) Refund(p *Params) (rsp Response, err error) {
 		err = fmt.Errorf("refund error, errcode=%s,errmsg=%s", rsp.ErrCode, rsp.ErrCodeDes)
 		return
 	}
-	err = fmt.Errorf("[msg : xmlUnmarshalError] [rawReturn : %s] [params : %s] [sign : %s]",
-		string(rawRet), str, sign)
+	err = fmt.Errorf("[msg : xmlUnmarshalError] [rawReturn : %s] [sign : %s]", string(rawRet), sign)
 	return
 }

+ 42 - 12
util/crypto.go

@@ -1,14 +1,23 @@
 package util
 
 import (
-	"bufio"
-	"bytes"
 	"crypto/aes"
 	"crypto/cipher"
+	"crypto/hmac"
 	"crypto/md5"
+	"crypto/sha256"
 	"encoding/base64"
 	"encoding/hex"
+	"errors"
 	"fmt"
+	"hash"
+	"strings"
+)
+
+// 微信签名算法方式
+const (
+	SignTypeMD5        = `MD5`
+	SignTypeHMACSHA256 = `HMAC-SHA256`
 )
 
 //EncryptMsg 加密消息
@@ -186,14 +195,35 @@ func decodeNetworkByteOrder(orderBytes []byte) (n uint32) {
 		uint32(orderBytes[3])
 }
 
-// MD5Sum 计算 32 位长度的 MD5 sum
-func MD5Sum(txt string) (sum string) {
-	h := md5.New()
-	buf := bufio.NewWriterSize(h, 128)
-	buf.WriteString(txt)
-	buf.Flush()
-	sign := make([]byte, hex.EncodedLen(h.Size()))
-	hex.Encode(sign, h.Sum(nil))
-	sum = string(bytes.ToUpper(sign))
-	return
+// CalculateSign 计算签名
+func CalculateSign(content, signType, key string) (string, error) {
+	var h hash.Hash
+	if signType == SignTypeMD5 {
+		h = md5.New()
+	} else {
+		h = hmac.New(sha256.New, []byte(key))
+	}
+
+	if _, err := h.Write([]byte(content)); err != nil {
+		return ``, err
+	}
+	return strings.ToUpper(hex.EncodeToString(h.Sum(nil))), nil
+}
+
+// ParamSign 计算所传参数的签名
+func ParamSign(p map[string]string, key string) (string, error) {
+	bizKey := "&key=" + key
+	str := OrderParam(p, bizKey)
+
+	var signType string
+	switch p["sign_type"] {
+	case SignTypeMD5, SignTypeHMACSHA256:
+		signType = p["sign_type"]
+	case ``:
+		signType = SignTypeMD5
+	default:
+		return ``, errors.New(`invalid sign_type`)
+	}
+
+	return CalculateSign(str, signType, key)
 }

+ 19 - 53
util/param.go

@@ -3,65 +3,31 @@ package util
 import (
 	"bytes"
 	"sort"
-	"strconv"
 )
 
 // OrderParam order params
-func OrderParam(source interface{}, bizKey string) (returnStr string) {
-	switch v := source.(type) {
-	case map[string]string:
-		keys := make([]string, 0, len(v))
-		for k := range v {
-			if k == "sign" {
-				continue
-			}
-			keys = append(keys, k)
+func OrderParam(p map[string]string, bizKey string) (returnStr string) {
+	keys := make([]string, 0, len(p))
+	for k := range p {
+		if k == "sign" {
+			continue
 		}
-		sort.Strings(keys)
-		var buf bytes.Buffer
-		for _, k := range keys {
-			if v[k] == "" {
-				continue
-			}
-			if buf.Len() > 0 {
-				buf.WriteByte('&')
-			}
-			buf.WriteString(k)
-			buf.WriteByte('=')
-			buf.WriteString(v[k])
-		}
-		buf.WriteString(bizKey)
-		returnStr = buf.String()
-	case map[string]interface{}:
-		keys := make([]string, 0, len(v))
-		for k := range v {
-			if k == "sign" {
-				continue
-			}
-			keys = append(keys, k)
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	var buf bytes.Buffer
+	for _, k := range keys {
+		if p[k] == "" {
+			continue
 		}
-		sort.Strings(keys)
-		var buf bytes.Buffer
-		for _, k := range keys {
-			if v[k] == "" {
-				continue
-			}
-			if buf.Len() > 0 {
-				buf.WriteByte('&')
-			}
-			buf.WriteString(k)
-			buf.WriteByte('=')
-			switch vv := v[k].(type) {
-			case string:
-				buf.WriteString(vv)
-			case int:
-				buf.WriteString(strconv.FormatInt(int64(vv), 10))
-			default:
-				panic("params type not supported")
-			}
+		if buf.Len() > 0 {
+			buf.WriteByte('&')
 		}
-		buf.WriteString(bizKey)
-		returnStr = buf.String()
+		buf.WriteString(k)
+		buf.WriteByte('=')
+		buf.WriteString(p[k])
 	}
+	buf.WriteString(bizKey)
+	returnStr = buf.String()
 	return
 }