Skip to content

Instantly share code, notes, and snippets.

@lunaspeed
Created May 29, 2020 02:03
Show Gist options
  • Save lunaspeed/a9561696636ef437023d3b6bdd7713d0 to your computer and use it in GitHub Desktop.
Save lunaspeed/a9561696636ef437023d3b6bdd7713d0 to your computer and use it in GitHub Desktop.
Ruby-style CSRF token generation and verification in Go
package csrf
import (
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"fmt"

	"github.com/theckman/go-securerandom"
)
const (
	// AuthenticityTokenLength is the number of bytes in both the decoded
	// session CSRF token and the one-time pad used to mask it. A masked
	// token therefore carries 2*AuthenticityTokenLength bytes before base64.
	AuthenticityTokenLength = 32
)
// Mainly a minor modification of, and addition to, https://devfun.tw/t/topic/1981.
// Follows Ruby's way of generating CSRF tokens: https://ruby-china.org/topics/35199

// generateCsrfToken returns a freshly masked CSRF token for the session token
// sessionCsrfToken (standard base64).
//
// A new one-time pad of AuthenticityTokenLength random bytes is drawn on every
// call. The result is base64(pad || pad XOR token), so the emitted string
// differs per call while verifyCsrfToken can still recover the same session
// token from it.
//
// It returns "" (after printing a diagnostic) when secure randomness is
// unavailable, when sessionCsrfToken is not valid standard base64, or when it
// does not decode to exactly AuthenticityTokenLength bytes.
func generateCsrfToken(sessionCsrfToken string) string {
	oneTimePad, err := securerandom.Bytes(AuthenticityTokenLength)
	if err != nil {
		fmt.Println("secure random is unavailable in environment:", err)
		return ""
	}
	tokenBytes, err := base64.StdEncoding.DecodeString(sessionCsrfToken)
	if err != nil {
		fmt.Println("failed to decode csrf token in session:", err.Error())
		return ""
	}
	// xorBytes walks the token and indexes the pad at the same positions.
	// A token longer than the pad would panic, and a shorter one would yield
	// a masked token that verifyCsrfToken rejects, so require an exact match.
	if len(tokenBytes) != AuthenticityTokenLength {
		fmt.Println("csrf token in session has length", len(tokenBytes), "want", AuthenticityTokenLength)
		return ""
	}
	encryptedCsrfToken := xorBytes(oneTimePad, tokenBytes)

	// Build pad || encrypted in a fresh buffer rather than appending onto the
	// pad slice, whose spare capacity (if any) belongs to the random source.
	masked := make([]byte, 0, 2*AuthenticityTokenLength)
	masked = append(masked, oneTimePad...)
	masked = append(masked, encryptedCsrfToken...)
	return base64.StdEncoding.EncodeToString(masked)
}
func verifyCsrfToken(sessionCsrfToken, webToken string) bool {
maskedToken, err := base64.StdEncoding.DecodeString(webToken)
if err != nil {
fmt.Println("decode base64 webCsrfToken fail", err.Error())
return false
}
if len(maskedToken) != session.AuthenticityTokenLength*2 {
fmt.Println("len fail , token is malformed")
}
sourceToken := hex.EncodeToString(xorBytes(maskedToken[:AuthenticityTokenLength], maskedToken[AuthenticityTokenLength:]))
token, err := base64.StdEncoding.DecodeString(sessionCsrfToken)
if err != nil {
fmt.Println("decode base64 sessionCsrfToken fail", err.Error())
return false
}
sessionToken := hex.EncodeToString(token)
if sourceToken != sessionToken {
fmt.Println("sourceToken != sessionToken,", sourceToken, sessionToken)
return false
}
return true
}
// xorBytes returns a newly allocated slice of len(s2) bytes whose i-th byte is
// s1[i] ^ s2[i]. Only the first len(s2) bytes of s1 are read. If s1 is shorter
// than s2, the call panics with an index-out-of-range error.
func xorBytes(s1 []byte, s2 []byte) []byte {
	out := make([]byte, len(s2))
	for i := 0; i < len(out); i++ {
		out[i] = s1[i] ^ s2[i]
	}
	return out
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment