yaotian hace 8 años
padre
commit
88be5d3ed0
Se han modificado 4 ficheros con 263 adiciones y 0 borrados
  1. 109 0
      mch/pay/pay.go
  2. 144 0
      mch/pay/tools.go
  3. 4 0
      server/context/context.go
  4. 6 0
      wechat.go

+ 109 - 0
mch/pay/pay.go

@@ -1,10 +1,119 @@
 package pay
 
 import (
+	"bytes"
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+
+	"github.com/astaxie/beego"
 	"github.com/yaotian/gowechat/server/context"
 )
 
+const (
+	ReturnCodeSuccess = "SUCCESS"
+	ReturnCodeFail    = "FAIL"
+)
+
+const (
+	ResultCodeSuccess = "SUCCESS"
+	ResultCodeFail    = "FAIL"
+)
+
+type Error struct {
+	XMLName    struct{} `xml:"xml"                  json:"-"`
+	ReturnCode string   `xml:"return_code"          json:"return_code"`
+	ReturnMsg  string   `xml:"return_msg,omitempty" json:"return_msg,omitempty"`
+}
+
+func (e *Error) Error() string {
+	return fmt.Sprintf("return_code: %q, return_msg: %q", e.ReturnCode, e.ReturnMsg)
+}
+
 //Pay pay
 type Pay struct {
 	*context.Context
 }
+
+//PostXML postXML
+func (c *Pay) PostXML(url string, req map[string]string, needSSL bool) (resp map[string]string, err error) {
+	bodyBuf := textBufferPool.Get().(*bytes.Buffer)
+	bodyBuf.Reset()
+	defer textBufferPool.Put(bodyBuf)
+
+	if err = FormatMapToXML(bodyBuf, req); err != nil {
+		return
+	}
+
+	//需要ssl,就需要ssl client
+	client := c.HTTPClient
+	if needSSL {
+		client = c.SHTTPClient
+	}
+
+	httpResp, err := client.Post(url, "text/xml; charset=utf-8", bodyBuf)
+	if err != nil {
+		return resp, err
+	}
+	defer httpResp.Body.Close()
+
+	if httpResp.StatusCode != http.StatusOK {
+		err = fmt.Errorf("http.Status: %s", httpResp.Status)
+		return
+	}
+
+	respBody, err := ioutil.ReadAll(httpResp.Body)
+	if err != nil {
+		return resp, err
+	}
+
+	if resp, err = ParseXMLToMap(bytes.NewReader(respBody)); err != nil {
+		return
+	}
+
+	beego.Debug(resp)
+
+	// 判断协议状态
+	ReturnCode, ok := resp["return_code"]
+	if !ok {
+		err = errors.New("no return_code parameter")
+		return
+	}
+	if ReturnCode != ReturnCodeSuccess {
+		err = &Error{
+			ReturnCode: ReturnCode,
+			ReturnMsg:  resp["return_msg"],
+		}
+		return
+	}
+
+	// 安全考虑, 做下验证
+	mchId, ok := resp["mch_id"]
+	if ok && mchId != c.MchID {
+		err = fmt.Errorf("mch_id mismatch, have: %q, want: %q", mchId, c.MchID)
+		return
+	}
+
+	//发送红包的情况,不需要验证这些,因为有的信息没有
+	if !needSSL {
+		appId, ok := resp["appid"]
+		if ok && appId != c.AppID {
+			err = fmt.Errorf("appid mismatch, have: %q, want: %q", appId, c.AppID)
+			return
+		}
+
+		// 认证签名
+		signature1, ok := resp["sign"]
+		if !ok {
+			err = errors.New("no sign parameter")
+			return
+		}
+		signature2 := Sign(resp, c.MchAPIKey, nil)
+		if signature1 != signature2 {
+			err = fmt.Errorf("check signature failed, \r\ninput: %q, \r\nlocal: %q", signature1, signature2)
+			return
+		}
+	}
+	return
+}

+ 144 - 0
mch/pay/tools.go

@@ -0,0 +1,144 @@
+package pay
+
+import (
+	"bytes"
+	"crypto/md5"
+	"encoding/hex"
+	"encoding/xml"
+	"errors"
+	"hash"
+	"io"
+	"sort"
+	"sync"
+)
+
+var textBufferPool = sync.Pool{
+	New: func() interface{} {
+		return bytes.NewBuffer(make([]byte, 0, 16<<10)) // 16KB
+	},
+}
+
+// FormatMapToXML marshal map[string]string to xmlWriter with xml format, the root node name is xml.
+//  NOTE: This function assumes the key of m map[string]string are legitimate xml name string
+//  that does not contain the required escape character!
+func FormatMapToXML(xmlWriter io.Writer, m map[string]string) (err error) {
+	if xmlWriter == nil {
+		return errors.New("nil xmlWriter")
+	}
+
+	if _, err = io.WriteString(xmlWriter, "<xml>"); err != nil {
+		return
+	}
+
+	for k, v := range m {
+		if _, err = io.WriteString(xmlWriter, "<"+k+">"); err != nil {
+			return
+		}
+		if err = xml.EscapeText(xmlWriter, []byte(v)); err != nil {
+			return
+		}
+		if _, err = io.WriteString(xmlWriter, "</"+k+">"); err != nil {
+			return
+		}
+	}
+
+	if _, err = io.WriteString(xmlWriter, "</xml>"); err != nil {
+		return
+	}
+	return
+}
+
+// 微信支付签名.
+//  parameters: 待签名的参数集合
+//  apiKey:     API密钥
+//  fn:         func() hash.Hash, 如果 fn == nil 则默认用 md5.New
+func Sign(parameters map[string]string, apiKey string, fn func() hash.Hash) string {
+	ks := make([]string, 0, len(parameters))
+	for k := range parameters {
+		if k == "sign" {
+			continue
+		}
+		ks = append(ks, k)
+	}
+	sort.Strings(ks)
+
+	if fn == nil {
+		fn = md5.New
+	}
+	h := fn()
+
+	buf := make([]byte, 256)
+	for _, k := range ks {
+		v := parameters[k]
+		if v == "" {
+			continue
+		}
+
+		buf = buf[:0]
+		buf = append(buf, k...)
+		buf = append(buf, '=')
+		buf = append(buf, v...)
+		buf = append(buf, '&')
+		h.Write(buf)
+	}
+	buf = buf[:0]
+	buf = append(buf, "key="...)
+	buf = append(buf, apiKey...)
+	h.Write(buf)
+
+	signature := make([]byte, h.Size()*2)
+	hex.Encode(signature, h.Sum(nil))
+	return string(bytes.ToUpper(signature))
+}
+
+// ParseXMLToMap parses xml reading from xmlReader and returns the first-level sub-node key-value set,
+// if the first-level sub-node contains child nodes, skip it.
+func ParseXMLToMap(xmlReader io.Reader) (m map[string]string, err error) {
+	if xmlReader == nil {
+		err = errors.New("nil xmlReader")
+		return
+	}
+
+	m = make(map[string]string)
+	var (
+		d     = xml.NewDecoder(xmlReader)
+		tk    xml.Token
+		depth = 0 // current xml.Token depth
+		key   string
+		value bytes.Buffer
+	)
+	for {
+		tk, err = d.Token()
+		if err != nil {
+			if err == io.EOF {
+				err = nil
+			}
+			return
+		}
+
+		switch v := tk.(type) {
+		case xml.StartElement:
+			depth++
+			switch depth {
+			case 2:
+				key = v.Name.Local
+				value.Reset()
+			case 3:
+				if err = d.Skip(); err != nil {
+					return
+				}
+				depth--
+				key = "" // key == "" indicates that the node with depth==2 has children
+			}
+		case xml.CharData:
+			if depth == 2 && key != "" {
+				value.Write(v)
+			}
+		case xml.EndElement:
+			if depth == 2 && key != "" {
+				m[key] = value.String()
+			}
+			depth--
+		}
+	}
+}

+ 4 - 0
server/context/context.go

@@ -27,6 +27,10 @@ type Context struct {
 
 	HTTPClient  *http.Client
 	SHTTPClient *http.Client //SSL client
+
+	//商户平台APIKey
+	MchAPIKey string
+	MchID     string
 }
 
 // Query returns the keyed url query value if it exists

+ 6 - 0
wechat.go

@@ -29,11 +29,14 @@ type Config struct {
 	EncodingAESKey string
 	Cache          cache.Cache
 
+	//mch商户平台需要的变量
 	//证书
 	SslCertFilePath string //证书文件的路径
 	SslKeyFilePath  string
 	SslCertContent  string //证书的内容
 	SslKeyContent   string
+	MchID           string
+	MchAPIKey       string //商户平台设置的api key
 }
 
 // NewWechat init
@@ -68,6 +71,9 @@ func initContext(cfg *Config, context *context.Context) {
 			context.SHTTPClient = client
 		}
 	}
+
+	context.MchAPIKey = cfg.MchAPIKey
+	context.MchID = cfg.MchID
 }
 
 // GetServer 消息管理