crypto.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. package util
  2. import (
  3. "crypto/aes"
  4. "crypto/cipher"
  5. "crypto/hmac"
  6. "crypto/md5"
  7. "crypto/sha256"
  8. "encoding/base64"
  9. "encoding/hex"
  10. "errors"
  11. "fmt"
  12. "hash"
  13. "strings"
  14. )
  15. // 微信签名算法方式
  16. const (
  17. SignTypeMD5 = `MD5`
  18. SignTypeHMACSHA256 = `HMAC-SHA256`
  19. )
  20. // EncryptMsg 加密消息
  21. func EncryptMsg(random, rawXMLMsg []byte, appID, aesKey string) (encrtptMsg []byte, err error) {
  22. defer func() {
  23. if e := recover(); e != nil {
  24. err = fmt.Errorf("panic error: err=%v", e)
  25. return
  26. }
  27. }()
  28. var key []byte
  29. key, err = aesKeyDecode(aesKey)
  30. if err != nil {
  31. panic(err)
  32. }
  33. ciphertext := AESEncryptMsg(random, rawXMLMsg, appID, key)
  34. encrtptMsg = []byte(base64.StdEncoding.EncodeToString(ciphertext))
  35. return
  36. }
  37. // AESEncryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
  38. //参考:github.com/chanxuehong/wechat.v2
  39. func AESEncryptMsg(random, rawXMLMsg []byte, appID string, aesKey []byte) (ciphertext []byte) {
  40. const (
  41. BlockSize = 32 // PKCS#7
  42. BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
  43. )
  44. appIDOffset := 20 + len(rawXMLMsg)
  45. contentLen := appIDOffset + len(appID)
  46. amountToPad := BlockSize - contentLen&BlockMask
  47. plaintextLen := contentLen + amountToPad
  48. plaintext := make([]byte, plaintextLen)
  49. // 拼接
  50. copy(plaintext[:16], random)
  51. encodeNetworkByteOrder(plaintext[16:20], uint32(len(rawXMLMsg)))
  52. copy(plaintext[20:], rawXMLMsg)
  53. copy(plaintext[appIDOffset:], appID)
  54. // PKCS#7 补位
  55. for i := contentLen; i < plaintextLen; i++ {
  56. plaintext[i] = byte(amountToPad)
  57. }
  58. // 加密
  59. block, err := aes.NewCipher(aesKey)
  60. if err != nil {
  61. panic(err)
  62. }
  63. mode := cipher.NewCBCEncrypter(block, aesKey[:16])
  64. mode.CryptBlocks(plaintext, plaintext)
  65. ciphertext = plaintext
  66. return
  67. }
  68. // DecryptMsg 消息解密
  69. func DecryptMsg(appID, encryptedMsg, aesKey string) (random, rawMsgXMLBytes []byte, err error) {
  70. defer func() {
  71. if e := recover(); e != nil {
  72. err = fmt.Errorf("panic error: err=%v", e)
  73. return
  74. }
  75. }()
  76. var encryptedMsgBytes, key, getAppIDBytes []byte
  77. encryptedMsgBytes, err = base64.StdEncoding.DecodeString(encryptedMsg)
  78. if err != nil {
  79. return
  80. }
  81. key, err = aesKeyDecode(aesKey)
  82. if err != nil {
  83. panic(err)
  84. }
  85. random, rawMsgXMLBytes, getAppIDBytes, err = AESDecryptMsg(encryptedMsgBytes, key)
  86. if err != nil {
  87. err = fmt.Errorf("消息解密失败,%v", err)
  88. return
  89. }
  90. if appID != string(getAppIDBytes) {
  91. err = fmt.Errorf("消息解密校验APPID失败")
  92. return
  93. }
  94. return
  95. }
  96. func aesKeyDecode(encodedAESKey string) (key []byte, err error) {
  97. if len(encodedAESKey) != 43 {
  98. err = fmt.Errorf("the length of encodedAESKey must be equal to 43")
  99. return
  100. }
  101. key, err = base64.StdEncoding.DecodeString(encodedAESKey + "=")
  102. if err != nil {
  103. return
  104. }
  105. if len(key) != 32 {
  106. err = fmt.Errorf("encodingAESKey invalid")
  107. return
  108. }
  109. return
  110. }
  111. // AESDecryptMsg ciphertext = AES_Encrypt[random(16B) + msg_len(4B) + rawXMLMsg + appId]
  112. //参考:github.com/chanxuehong/wechat.v2
  113. func AESDecryptMsg(ciphertext []byte, aesKey []byte) (random, rawXMLMsg, appID []byte, err error) {
  114. const (
  115. BlockSize = 32 // PKCS#7
  116. BlockMask = BlockSize - 1 // BLOCK_SIZE 为 2^n 时, 可以用 mask 获取针对 BLOCK_SIZE 的余数
  117. )
  118. if len(ciphertext) < BlockSize {
  119. err = fmt.Errorf("the length of ciphertext too short: %d", len(ciphertext))
  120. return
  121. }
  122. if len(ciphertext)&BlockMask != 0 {
  123. err = fmt.Errorf("ciphertext is not a multiple of the block size, the length is %d", len(ciphertext))
  124. return
  125. }
  126. plaintext := make([]byte, len(ciphertext)) // len(plaintext) >= BLOCK_SIZE
  127. // 解密
  128. block, err := aes.NewCipher(aesKey)
  129. if err != nil {
  130. panic(err)
  131. }
  132. mode := cipher.NewCBCDecrypter(block, aesKey[:16])
  133. mode.CryptBlocks(plaintext, ciphertext)
  134. // PKCS#7 去除补位
  135. amountToPad := int(plaintext[len(plaintext)-1])
  136. if amountToPad < 1 || amountToPad > BlockSize {
  137. err = fmt.Errorf("the amount to pad is incorrect: %d", amountToPad)
  138. return
  139. }
  140. plaintext = plaintext[:len(plaintext)-amountToPad]
  141. // 反拼接
  142. // len(plaintext) == 16+4+len(rawXMLMsg)+len(appId)
  143. if len(plaintext) <= 20 {
  144. err = fmt.Errorf("plaintext too short, the length is %d", len(plaintext))
  145. return
  146. }
  147. rawXMLMsgLen := int(decodeNetworkByteOrder(plaintext[16:20]))
  148. if rawXMLMsgLen < 0 {
  149. err = fmt.Errorf("incorrect msg length: %d", rawXMLMsgLen)
  150. return
  151. }
  152. appIDOffset := 20 + rawXMLMsgLen
  153. if len(plaintext) <= appIDOffset {
  154. err = fmt.Errorf("msg length too large: %d", rawXMLMsgLen)
  155. return
  156. }
  157. random = plaintext[:16:20]
  158. rawXMLMsg = plaintext[20:appIDOffset:appIDOffset]
  159. appID = plaintext[appIDOffset:]
  160. return
  161. }
  162. // 把整数 n 格式化成 4 字节的网络字节序
  163. func encodeNetworkByteOrder(orderBytes []byte, n uint32) {
  164. orderBytes[0] = byte(n >> 24)
  165. orderBytes[1] = byte(n >> 16)
  166. orderBytes[2] = byte(n >> 8)
  167. orderBytes[3] = byte(n)
  168. }
  169. // 从 4 字节的网络字节序里解析出整数
  170. func decodeNetworkByteOrder(orderBytes []byte) (n uint32) {
  171. return uint32(orderBytes[0])<<24 |
  172. uint32(orderBytes[1])<<16 |
  173. uint32(orderBytes[2])<<8 |
  174. uint32(orderBytes[3])
  175. }
  176. // CalculateSign 计算签名
  177. func CalculateSign(content, signType, key string) (string, error) {
  178. var h hash.Hash
  179. if signType == SignTypeHMACSHA256 {
  180. h = hmac.New(sha256.New, []byte(key))
  181. } else {
  182. h = md5.New()
  183. }
  184. if _, err := h.Write([]byte(content)); err != nil {
  185. return ``, err
  186. }
  187. return strings.ToUpper(hex.EncodeToString(h.Sum(nil))), nil
  188. }
  189. // ParamSign 计算所传参数的签名
  190. func ParamSign(p map[string]string, key string) (string, error) {
  191. bizKey := "&key=" + key
  192. str := OrderParam(p, bizKey)
  193. var signType string
  194. switch p["sign_type"] {
  195. case SignTypeMD5, SignTypeHMACSHA256:
  196. signType = p["sign_type"]
  197. case ``:
  198. signType = SignTypeMD5
  199. default:
  200. return ``, errors.New(`invalid sign_type`)
  201. }
  202. return CalculateSign(str, signType, key)
  203. }