rsa_ext.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. package utils
  2. import (
  3. "bytes"
  4. "crypto/rand"
  5. "crypto/rsa"
  6. "crypto/x509"
  7. "encoding/pem"
  8. "errors"
  9. "io"
  10. "io/ioutil"
  11. "math/big"
  12. )
  // Sentinel errors for the RSA helpers in this file. Names keep their original
// spelling (e.g. "ToLarge", "Dismatch") for API compatibility.
var (
	// ErrDataToLarge is returned when an RSA input block is not smaller than the modulus.
	ErrDataToLarge = errors.New("message too long for RSA public key size")

	// ErrDataLen is returned when a block has the wrong length for the key.
	ErrDataLen = errors.New("data length error")

	// ErrDataBroken is returned when a decoded block does not start with 0x00.
	ErrDataBroken = errors.New("data broken, first byte is not zero")

	// ErrKeyPairDismatch is returned when a decoded block's padding does not look
	// like it was produced with the matching private key.
	ErrKeyPairDismatch = errors.New("data is not encrypted by the private key")

	// ErrDecryption is returned by the raw private-key operation on invalid input.
	ErrDecryption = errors.New("decryption error")

	// ErrPublicKey carries the same text as the errors getPubKey builds inline,
	// but is a distinct value, so errors.Is does not match those.
	ErrPublicKey = errors.New("get public key error")

	// ErrPrivateKey carries the same text as the errors getPriKey builds inline,
	// but is a distinct value, so errors.Is does not match those.
	ErrPrivateKey = errors.New("get private key error")

	// ErrGetRedisFail is not referenced in this file; presumably used elsewhere in the package.
	ErrGetRedisFail = errors.New("get redis pool fail")

	// ErrGetRedisConnectFail is not referenced in this file; presumably used elsewhere in the package.
	ErrGetRedisConnectFail = errors.New("get redis conn fail")
)
  // getPubKey parses a PEM-encoded RSA public key.
//
// The DER payload is parsed as PKIX (SubjectPublicKeyInfo) first, then as
// PKCS #1; the PEM block type is not inspected. An error is returned when no
// PEM block is found, when neither encoding parses (the PKCS #1 error is
// reported), or when the PKIX key is not an RSA key.
func getPubKey(publickey []byte) (*rsa.PublicKey, error) {
	block, _ := pem.Decode(publickey)
	if block == nil {
		return nil, errors.New("get public key error")
	}

	parsed, err := x509.ParsePKIXPublicKey(block.Bytes)
	if err != nil {
		pkcs1, pkcs1Err := x509.ParsePKCS1PublicKey(block.Bytes)
		if pkcs1Err != nil {
			return nil, pkcs1Err
		}
		return pkcs1, nil
	}

	// A PKIX block can also hold ECDSA/Ed25519/... keys; a bare type
	// assertion would panic on those.
	pub, ok := parsed.(*rsa.PublicKey)
	if !ok {
		return nil, errors.New("get public key error: not an RSA public key")
	}
	return pub, nil
}
  // getPriKey parses a PEM-encoded RSA private key.
//
// The DER payload is parsed as PKCS #8 first and, if that fails, as PKCS #1;
// the PEM block type is not inspected. An error is returned when no PEM block
// is found, when neither encoding parses (the PKCS #8 error is reported), or
// when the PKCS #8 key is not an RSA key.
func getPriKey(privatekey []byte) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(privatekey)
	if block == nil {
		return nil, errors.New("get private key error")
	}

	parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		// Fall back to PKCS #1 so traditional "RSA PRIVATE KEY" files load too.
		if pri, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes); pkcs1Err == nil {
			return pri, nil
		}
		return nil, err
	}

	// PKCS #8 can also wrap ECDSA/Ed25519/... keys; a bare type assertion
	// would panic on those.
	pri, ok := parsed.(*rsa.PrivateKey)
	if !ok {
		return nil, errors.New("get private key error: not an RSA private key")
	}
	return pri, nil
}
  57. // 公钥加密或解密byte
  58. func pubKeyByte(pub *rsa.PublicKey, in []byte, isEncrytp bool) ([]byte, error) {
  59. k := (pub.N.BitLen() + 7) / 8
  60. if isEncrytp {
  61. k = k - 11
  62. }
  63. if len(in) <= k {
  64. if isEncrytp {
  65. return rsa.EncryptPKCS1v15(rand.Reader, pub, in)
  66. } else {
  67. return pubKeyDecrypt(pub, in)
  68. }
  69. } else {
  70. iv := make([]byte, k)
  71. out := bytes.NewBuffer(iv)
  72. if err := pubKeyIO(pub, bytes.NewReader(in), out, isEncrytp); err != nil {
  73. return nil, err
  74. }
  75. return ioutil.ReadAll(out)
  76. }
  77. }
  78. // 私钥加密或解密byte
  79. func priKeyByte(pri *rsa.PrivateKey, in []byte, isEncrytp bool) ([]byte, error) {
  80. k := (pri.N.BitLen() + 7) / 8
  81. if isEncrytp {
  82. k = k - 11
  83. }
  84. if len(in) <= k {
  85. if isEncrytp {
  86. return priKeyEncrypt(rand.Reader, pri, in)
  87. } else {
  88. return rsa.DecryptPKCS1v15(rand.Reader, pri, in)
  89. }
  90. } else {
  91. iv := make([]byte, k)
  92. out := bytes.NewBuffer(iv)
  93. if err := priKeyIO(pri, bytes.NewReader(in), out, isEncrytp); err != nil {
  94. return nil, err
  95. }
  96. return ioutil.ReadAll(out)
  97. }
  98. }
  99. // 公钥加密或解密Reader
  100. func pubKeyIO(pub *rsa.PublicKey, in io.Reader, out io.Writer, isEncrytp bool) (err error) {
  101. k := (pub.N.BitLen()) / 8
  102. if isEncrytp {
  103. k = k - 11
  104. }
  105. buf := make([]byte, k)
  106. var b []byte
  107. size := 0
  108. for {
  109. size, err = in.Read(buf)
  110. if err != nil {
  111. if err == io.EOF {
  112. return nil
  113. }
  114. return err
  115. }
  116. if size < k {
  117. b = buf[:size]
  118. } else {
  119. b = buf
  120. }
  121. if isEncrytp {
  122. b, err = rsa.EncryptPKCS1v15(rand.Reader, pub, b)
  123. } else {
  124. b, err = pubKeyDecrypt(pub, b)
  125. }
  126. if err != nil {
  127. return err
  128. }
  129. if _, err = out.Write(b); err != nil {
  130. return err
  131. }
  132. }
  133. return nil
  134. }
  135. // 私钥加密或解密Reader
  136. func priKeyIO(pri *rsa.PrivateKey, r io.Reader, w io.Writer, isEncrytp bool) (err error) {
  137. k := (pri.N.BitLen()) / 8
  138. if isEncrytp {
  139. k = k - 11
  140. }
  141. buf := make([]byte, k)
  142. var b []byte
  143. size := 0
  144. for {
  145. size, err = r.Read(buf)
  146. if err != nil {
  147. if err == io.EOF {
  148. return nil
  149. }
  150. return err
  151. }
  152. if size < k {
  153. b = buf[:size]
  154. } else {
  155. b = buf
  156. }
  157. if isEncrytp {
  158. b, err = priKeyEncrypt(rand.Reader, pri, b)
  159. } else {
  160. b, err = rsa.DecryptPKCS1v15(rand.Reader, pri, b)
  161. }
  162. if err != nil {
  163. return err
  164. }
  165. if _, err = w.Write(b); err != nil {
  166. return err
  167. }
  168. }
  169. return nil
  170. }
  171. // 公钥解密
  172. func pubKeyDecrypt(pub *rsa.PublicKey, data []byte) ([]byte, error) {
  173. k := (pub.N.BitLen()) / 8
  174. if k != len(data) {
  175. return nil, ErrDataLen
  176. }
  177. m := new(big.Int).SetBytes(data)
  178. if m.Cmp(pub.N) > 0 {
  179. return nil, ErrDataToLarge
  180. }
  181. m.Exp(m, big.NewInt(int64(pub.E)), pub.N)
  182. d := leftPad(m.Bytes(), k)
  183. if d[0] != 0 {
  184. return nil, ErrDataBroken
  185. }
  186. if d[1] != 0 && d[1] != 1 {
  187. return nil, ErrKeyPairDismatch
  188. }
  189. var i = 2
  190. for ; i < len(d); i++ {
  191. if d[i] == 0 {
  192. break
  193. }
  194. }
  195. i++
  196. if i == len(d) {
  197. return nil, nil
  198. }
  199. return d[i:], nil
  200. }
  201. // 私钥加密
  202. func priKeyEncrypt(rand io.Reader, priv *rsa.PrivateKey, hashed []byte) ([]byte, error) {
  203. tLen := len(hashed)
  204. k := (priv.N.BitLen()) / 8
  205. if k < tLen+11 {
  206. return nil, ErrDataLen
  207. }
  208. em := make([]byte, k)
  209. em[1] = 1
  210. for i := 2; i < k-tLen-1; i++ {
  211. em[i] = 0xff
  212. }
  213. copy(em[k-tLen:k], hashed)
  214. m := new(big.Int).SetBytes(em)
  215. c, err := decrypt(rand, priv, m)
  216. if err != nil {
  217. return nil, err
  218. }
  219. copyWithLeftPad(em, c.Bytes())
  220. return em, nil
  221. }
  // Shared big.Int constants (copied from crypto/rsa). Every user shares the
// same pointer, so callers must treat them as read-only.
var (
	bigZero = big.NewInt(0)
	bigOne  = big.NewInt(1)
)
  // encrypt applies the raw RSA public-key operation c = m^e mod N, storing the
// result in c and returning c. No padding is applied (copied from crypto/rsa).
func encrypt(c *big.Int, pub *rsa.PublicKey, m *big.Int) *big.Int {
	return c.Exp(m, big.NewInt(int64(pub.E)), pub.N)
}
  231. // 从crypto/rsa复制
  232. func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, err error) {
  233. if c.Cmp(priv.N) > 0 {
  234. err = ErrDecryption
  235. return
  236. }
  237. var ir *big.Int
  238. if random != nil {
  239. var r *big.Int
  240. for {
  241. r, err = rand.Int(random, priv.N)
  242. if err != nil {
  243. return
  244. }
  245. if r.Cmp(bigZero) == 0 {
  246. r = bigOne
  247. }
  248. var ok bool
  249. ir, ok = modInverse(r, priv.N)
  250. if ok {
  251. break
  252. }
  253. }
  254. bigE := big.NewInt(int64(priv.E))
  255. rpowe := new(big.Int).Exp(r, bigE, priv.N)
  256. cCopy := new(big.Int).Set(c)
  257. cCopy.Mul(cCopy, rpowe)
  258. cCopy.Mod(cCopy, priv.N)
  259. c = cCopy
  260. }
  261. if priv.Precomputed.Dp == nil {
  262. m = new(big.Int).Exp(c, priv.D, priv.N)
  263. } else {
  264. m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0])
  265. m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1])
  266. m.Sub(m, m2)
  267. if m.Sign() < 0 {
  268. m.Add(m, priv.Primes[0])
  269. }
  270. m.Mul(m, priv.Precomputed.Qinv)
  271. m.Mod(m, priv.Primes[0])
  272. m.Mul(m, priv.Primes[1])
  273. m.Add(m, m2)
  274. for i, values := range priv.Precomputed.CRTValues {
  275. prime := priv.Primes[2+i]
  276. m2.Exp(c, values.Exp, prime)
  277. m2.Sub(m2, m)
  278. m2.Mul(m2, values.Coeff)
  279. m2.Mod(m2, prime)
  280. if m2.Sign() < 0 {
  281. m2.Add(m2, prime)
  282. }
  283. m2.Mul(m2, values.R)
  284. m.Add(m, m2)
  285. }
  286. }
  287. if ir != nil {
  288. m.Mul(m, ir)
  289. m.Mod(m, priv.N)
  290. }
  291. return
  292. }
  // copyWithLeftPad right-aligns src inside dest and zeroes the leading
// len(dest)-len(src) bytes (copied from crypto/rsa). It panics when src is
// longer than dest.
func copyWithLeftPad(dest, src []byte) {
	pad := len(dest) - len(src)
	head, tail := dest[:pad], dest[pad:]
	for i := range head {
		head[i] = 0
	}
	copy(tail, src)
}
  // nonZeroRandomBytes fills s with bytes from rand such that none is zero
// (copied from crypto/rsa). Each zero byte is redrawn from rand and XOR'd
// with 0x42 until it becomes non-zero. The first read error is returned,
// possibly leaving s partially filled.
func nonZeroRandomBytes(s []byte, rand io.Reader) (err error) {
	if _, err = io.ReadFull(rand, s); err != nil {
		return err
	}
	for idx := range s {
		for s[idx] == 0 {
			if _, err = io.ReadFull(rand, s[idx:idx+1]); err != nil {
				return err
			}
			s[idx] ^= 0x42
		}
	}
	return nil
}
  // leftPad returns a new slice of length size holding input right-aligned
// behind zero bytes (copied from crypto/rsa). When input is longer than size,
// only its first size bytes are kept.
func leftPad(input []byte, size int) (out []byte) {
	out = make([]byte, size)
	offset := size - len(input)
	if offset < 0 {
		offset = 0
	}
	copy(out[offset:], input)
	return out
}
  328. // 从crypto/rsa复制
  329. func modInverse(a, n *big.Int) (ia *big.Int, ok bool) {
  330. g := new(big.Int)
  331. x := new(big.Int)
  332. y := new(big.Int)
  333. g.GCD(x, y, a, n)
  334. if g.Cmp(bigOne) != 0 {
  335. return
  336. }
  337. if x.Cmp(bigOne) < 0 {
  338. x.Add(x, n)
  339. }
  340. return x, true
  341. }