最近打算做一款类似腾讯《脑力达人》的H5游戏。之前打算用skynet来做,所以给skynet增加了websocket模块
(https://github.com/Skycrab/skynet_websocket)。刚好最近在学习golang,考虑之下决定用golang来实现,说不定过段时间
还能搭一个golang游戏服务器。之前我一直认为Python是我的真爱,但现在真心喜欢上了golang,也许这也弥补了我在静态语言方面
的缺失吧:虽然C/C++还算熟悉,但没有工程经验,始终觉得缺少点什么。我相信golang以后会在服务器领域占有一席之地,
现在研究它也算是一种投资吧。等golang越来越成熟、gc越来越高效,会有很多人投入golang的怀抱。
我始终相信,一门语言就是一种文化。当我写Python时,很少会考虑效率,想得更多的是如何简洁、优雅地实现;但当我写golang时,
时不时会左右比较,在int32与int64之间徘徊,估算这次大概需要预分配多少byte的内存……在Python中即使你考虑了这些,
大多也是徒劳,因为语言本身并没有提供相应的控制手段。语言的文化,让我痴迷。
算上前一篇写的定时器(http://blog.csdn.net/yueguanghaidao/article/details/46290539)和本篇的websocket,还差不少东西才能组成游戏服务器,慢慢填坑吧。
有人说,golang的websocket库已经很多了,何必重复造轮子?但自己写的后期好优化、更新方便,而且造轮子是快速学习的途径。
如果时间允许,多造轮子,中途会收获很多。
github地址:https://github.com/Skycrab/code/tree/master/Go/websocket
首先看看如何使用:
package websocket
import (
"fmt"
"net/http"
"testing"
)
// MyHandler is a demo WsHandler that prints every event to stdout.
type MyHandler struct{}

// CheckOrigin accepts every origin.
func (MyHandler) CheckOrigin(origin, host string) bool { return true }

// OnOpen logs the event and greets the client with a text frame.
func (MyHandler) OnOpen(ws *Websocket) {
	fmt.Println("OnOpen")
	ws.SendText([]byte("hello world from server"))
}

// OnMessage logs the received message and its length in bytes.
func (MyHandler) OnMessage(ws *Websocket, message []byte) {
	fmt.Println("OnMessage:", string(message), len(message))
}

// OnClose logs the close status code and reason sent by the peer.
func (MyHandler) OnClose(ws *Websocket, code uint16, reason []byte) {
	fmt.Println("OnClose", code, string(reason))
}

// OnPong logs the payload of a received pong.
func (MyHandler) OnPong(ws *Websocket, data []byte) {
	fmt.Println("OnPong:", string(data))
}
// TestWebsocket starts a demo server on :8001 that upgrades requests to /ws
// and drives them with MyHandler. It blocks until ListenAndServe fails, so it
// is meant for manual testing with a browser or other WebSocket client.
func TestWebsocket(t *testing.T) {
	http.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
		fmt.Println("...")
		var opt = Option{MyHandler{}, false}
		ws, err := New(w, r, &opt)
		if err != nil {
			// Handlers run on their own goroutine, and t.Fatal/FailNow may only
			// be called from the test goroutine, so report and bail out instead.
			t.Error(err.Error())
			return
		}
		ws.Start()
	})
	fmt.Println("server start")
	if err := http.ListenAndServe(":8001", nil); err != nil {
		t.Fatal(err)
	}
}
使用方法和之前的skynet版本类似,都仿照了tornado websocket的执行方式。
MyHandler实现了WsHandler接口。如果你并不关注所有事件,可以在自己的类型中嵌入(embed)WsDefaultHandler(Go没有继承),
WsDefaultHandler为所有的事件提供了默认实现。
通过Option实现了默认参数功能:第二个字段MaskOutgoing表示是否对发送的数据做mask。按照协议,客户端发送的数据必须mask,服务端则不能mask,所以默认为false。
由于暂时没有websocket client的需求,所以没有提供,需要时再添加吧。
对比一下golang和lua的实现,代码行数并没有增加多少:golang是400行,lua是340行。不得不说golang的编码效率的确
赶得上动态语言。在编写golang和lua实现时,我明显感觉到静态语言具有很大优势:lua的出错提示不给力,这也是动态语言的
痛处吧。好消息是Python 3.5引入了类型提示(type hints),配合mypy等工具就能做静态类型检查,我觉得的确是一大利器。
在这里把代码贴一下,方便查看。
package websocket
import (
	"bufio"
	"bytes"
	"crypto/rand"
	"crypto/sha1"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strings"
)
// Handshake errors returned by New when the opening request is not acceptable.

// ErrUpgrade reports a missing or non-"websocket" Upgrade header.
var ErrUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"")

// ErrConnection reports a Connection header that does not mention "upgrade".
var ErrConnection = errors.New("\"Connection\" must be \"Upgrade\"")

// ErrCrossOrigin reports an origin rejected by WsHandler.CheckOrigin.
var ErrCrossOrigin = errors.New("Cross origin websockets not allowed")

// ErrSecVersion reports a Sec-WebSocket-Version other than 13.
var ErrSecVersion = errors.New("HTTP/1.1 Upgrade Required\r\nSec-WebSocket-Version: 13\r\n\r\n")

// ErrSecKey reports a missing Sec-WebSocket-Key header.
var ErrSecKey = errors.New("\"Sec-WebSocket-Key\" must not be nil")

// ErrHijacker reports a ResponseWriter that cannot hand over its connection.
var ErrHijacker = errors.New("Not implement http.Hijacker")

// Framing errors returned while reading frames.

// ErrReservedBits reports a frame with RSV1-3 set.
var ErrReservedBits = errors.New("Reserved_bits show using undefined extensions")

// ErrFrameOverload reports a control frame whose payload is 126 bytes or more.
var ErrFrameOverload = errors.New("Control frame payload overload")

// ErrFrameFragmented reports a control frame without the FIN bit.
var ErrFrameFragmented = errors.New("Control frame must not be fragmented")

// ErrInvalidOpcode reports an opcode the implementation does not handle.
var ErrInvalidOpcode = errors.New("Invalid frame opcode")

// crlf terminates handshake header lines.
var crlf = []byte("\r\n")

// challengeKey is the GUID appended to Sec-WebSocket-Key before hashing
// to compute Sec-WebSocket-Accept.
var challengeKey = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
// WsHandler receives the events of a Websocket connection. The design is
// ported from https://github.com/Skycrab/skynet_websocket/blob/master/websocket.lua
type WsHandler interface {
	// CheckOrigin is consulted during the handshake when the request carries
	// an origin; returning false rejects the handshake (New answers 403).
	CheckOrigin(origin, host string) bool
	// OnOpen is called after the 101 handshake response has been written.
	OnOpen(ws *Websocket)
	// OnMessage is called by Recv with each non-empty reassembled message.
	OnMessage(ws *Websocket, message []byte)
	// OnPong is called with the payload of every received pong frame.
	OnPong(ws *Websocket, data []byte)
	// OnClose is called when the peer's close frame arrives, with its status
	// code (0 if absent) and reason (nil if absent).
	OnClose(ws *Websocket, code uint16, reason []byte)
}
// WsDefaultHandler implements every WsHandler method as a no-op, and its
// CheckOrigin accepts all origins. Embed it in your own handler to override
// only the events you care about.
type WsDefaultHandler struct {
	// checkOriginOr was described as "whether to check the origin, default true".
	// NOTE(review): no method reads it; CheckOrigin accepts every origin
	// regardless. Confirm whether origin checking should be implemented.
	checkOriginOr bool
}

// CheckOrigin accepts every origin.
func (WsDefaultHandler) CheckOrigin(origin, host string) bool { return true }

// OnOpen does nothing.
func (WsDefaultHandler) OnOpen(ws *Websocket) {}

// OnMessage does nothing.
func (WsDefaultHandler) OnMessage(ws *Websocket, message []byte) {}

// OnClose does nothing.
func (WsDefaultHandler) OnClose(ws *Websocket, code uint16, reason []byte) {}

// OnPong does nothing.
func (WsDefaultHandler) OnPong(ws *Websocket, data []byte) {}
// Websocket is a server-side WebSocket connection created by New.
type Websocket struct {
	conn    net.Conn          // connection hijacked from the HTTP server
	rw      *bufio.ReadWriter // buffered I/O returned by Hijack; frames are read from and written through it
	handler WsHandler         // receives connection events

	// clientTerminated is set once the peer's close frame has been received;
	// serverTerminated is set once our close frame has been sent.
	clientTerminated, serverTerminated bool

	maskOutgoing bool // whether SendFrame sets the mask bit on outgoing frames
}
// Option configures New. Passing a nil *Option selects WsDefaultHandler and
// no outgoing masking.
type Option struct {
Handler WsHandler // receives connection events; intended default: WsDefaultHandler
MaskOutgoing bool // whether outgoing frames are masked (clients must mask, servers must not); default false
}
// challengeResponse builds the raw "101 Switching Protocols" response for a
// client whose Sec-WebSocket-Key is key. The Sec-WebSocket-Accept value is
// base64(SHA-1(key + challengeKey)). A Sec-WebSocket-Protocol line is added
// only when protocol is non-empty. The result ends with the blank line that
// terminates the header block.
func challengeResponse(key, protocol string) []byte {
	digest := sha1.Sum(append([]byte(key), challengeKey...))

	var resp bytes.Buffer
	resp.WriteString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ")
	resp.WriteString(base64.StdEncoding.EncodeToString(digest[:]))
	resp.Write(crlf)

	if protocol != "" {
		resp.WriteString("Sec-WebSocket-Protocol: ")
		resp.WriteString(protocol)
		resp.Write(crlf)
	}

	resp.Write(crlf)
	return resp.Bytes()
}
// acceptConnection validates the client's opening handshake and returns the
// 101 response to write back.
//
// It requires:
//   - an Upgrade header equal to "websocket" (case-insensitive);
//   - a Connection header containing "upgrade";
//   - Sec-WebSocket-Version 13;
//   - a non-empty Sec-WebSocket-Key.
//
// When the request carries an origin, h.CheckOrigin must accept it. If the
// client offers subprotocols, the first one listed is echoed back.
func acceptConnection(r *http.Request, h WsHandler) (challenge []byte, err error) {
	// Upgrade header should be present and equal to "websocket".
	if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
		return nil, ErrUpgrade
	}
	// Connection may list several tokens, and proxies/load balancers may
	// rewrite it, so only require that "upgrade" appears somewhere.
	if !strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") {
		return nil, ErrConnection
	}
	if r.Header.Get("Sec-Websocket-Version") != "13" {
		return nil, ErrSecVersion
	}
	// Version 13 clients send "Origin"; version 8 used "Sec-WebSocket-Origin".
	origin := r.Header.Get("Origin")
	if origin == "" {
		origin = r.Header.Get("Sec-Websocket-Origin")
	}
	if origin != "" {
		// net/http promotes the Host header to r.Host and removes it from
		// r.Header, so the header lookup only helps hand-built requests.
		host := r.Host
		if host == "" {
			host = r.Header.Get("Host")
		}
		if !h.CheckOrigin(origin, host) {
			return nil, ErrCrossOrigin
		}
	}
	key := r.Header.Get("Sec-Websocket-Key")
	if key == "" {
		return nil, ErrSecKey
	}
	// Pick the first offered subprotocol; entries are comma-separated and may
	// carry surrounding whitespace ("chat, superchat").
	protocol := r.Header.Get("Sec-Websocket-Protocol")
	if idx := strings.IndexByte(protocol, ','); idx != -1 {
		protocol = protocol[:idx]
	}
	return challengeResponse(key, strings.TrimSpace(protocol)), nil
}
// websocketMask XORs data in place with the 4-byte masking key mask, cycling
// through the key byte by byte. Applying it twice restores the original data.
// mask must hold at least 4 bytes whenever data is 4 bytes or longer.
func websocketMask(mask []byte, data []byte) {
	for i, b := range data {
		data[i] = b ^ mask[i&3]
	}
}
// New performs the server side of the WebSocket opening handshake for r,
// hijacks the connection from w, writes the 101 response and calls the
// handler's OnOpen before returning the connection.
//
// opt may be nil. A nil opt, or a nil opt.Handler, selects WsDefaultHandler.
// MaskOutgoing defaults to false.
//
// If the handshake is invalid, an HTTP error response is written to w and the
// handshake error is returned:
//   - 403 when the origin is rejected;
//   - 426 with "Sec-WebSocket-Version: 13" for an unsupported version;
//   - 400 otherwise.
//
// ErrHijacker is returned when w cannot hand over its connection.
func New(w http.ResponseWriter, r *http.Request, opt *Option) (*Websocket, error) {
	var h WsHandler = WsDefaultHandler{true}
	var maskOutgoing bool
	if opt != nil {
		// A nil Handler would panic on the first event; fall back to the default.
		if opt.Handler != nil {
			h = opt.Handler
		}
		maskOutgoing = opt.MaskOutgoing
	}

	challenge, err := acceptConnection(r, h)
	if err != nil {
		code := http.StatusBadRequest
		switch err {
		case ErrCrossOrigin:
			code = http.StatusForbidden
		case ErrSecVersion:
			// RFC 6455 section 4.4: advertise the supported version with 426.
			w.Header().Set("Sec-WebSocket-Version", "13")
			code = http.StatusUpgradeRequired
		}
		w.WriteHeader(code)
		w.Write([]byte(err.Error()))
		return nil, err
	}

	hj, ok := w.(http.Hijacker)
	if !ok {
		return nil, ErrHijacker
	}
	conn, rw, err := hj.Hijack()
	if err != nil {
		return nil, err
	}

	ws := &Websocket{
		conn:         conn,
		rw:           rw,
		handler:      h,
		maskOutgoing: maskOutgoing,
	}
	// The buffered writer is still empty, so writing straight to conn is safe.
	if _, err := conn.Write(challenge); err != nil {
		conn.Close()
		return nil, err
	}
	ws.handler.OnOpen(ws)
	return ws, nil
}
// read fills buf completely from the connection's buffered reader. As with
// io.ReadFull, it returns io.EOF if nothing was read and
// io.ErrUnexpectedEOF if the stream ended part-way.
func (ws *Websocket) read(buf []byte) (err error) {
	_, err = io.ReadFull(ws.rw.Reader, buf)
	return
}
// SendFrame writes one frame with the given FIN flag and opcode carrying data,
// then flushes the buffered writer.
//
// The payload length uses the shortest encoding RFC 6455 section 5.2 allows:
// 7-bit, 16-bit, or 64-bit. If the connection was created with MaskOutgoing,
// a fresh random 4-byte masking key from crypto/rand follows the length, and
// the payload is masked with it. The caller's data slice is never modified.
func (ws *Websocket) SendFrame(fin bool, opcode byte, data []byte) error {
	length := len(data)
	// Largest header: 2 bytes + 8-byte extended length + 4-byte masking key.
	buf := make([]byte, 0, length+14)

	first := opcode
	if fin {
		first |= 0x80
	}
	buf = append(buf, first)

	var maskBit byte
	if ws.maskOutgoing {
		maskBit = 0x80
	}
	switch {
	case length < 126:
		buf = append(buf, byte(length)|maskBit)
	case length <= 0xFFFF: // 65535 still fits the 16-bit form
		buf = append(buf, 126|maskBit, 0, 0)
		binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(length))
	default:
		buf = append(buf, 127|maskBit, 0, 0, 0, 0, 0, 0, 0, 0)
		binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(length))
	}

	if ws.maskOutgoing {
		// A set mask bit must be followed by the key and a masked payload,
		// otherwise the peer mis-parses the frame.
		var key [4]byte
		if _, err := rand.Read(key[:]); err != nil {
			return err
		}
		buf = append(buf, key[:]...)
		start := len(buf)
		buf = append(buf, data...)
		websocketMask(key[:], buf[start:])
	} else {
		buf = append(buf, data...)
	}

	if _, err := ws.rw.Write(buf); err != nil {
		return err
	}
	return ws.rw.Flush()
}
// sendFinal writes data as a single unfragmented (FIN=1) frame with opcode.
func (ws *Websocket) sendFinal(opcode byte, data []byte) error {
	return ws.SendFrame(true, opcode, data)
}

// SendText sends data as one text frame (opcode 0x1).
func (ws *Websocket) SendText(data []byte) error { return ws.sendFinal(0x1, data) }

// SendBinary sends data as one binary frame (opcode 0x2).
func (ws *Websocket) SendBinary(data []byte) error { return ws.sendFinal(0x2, data) }

// SendPing sends a ping control frame (opcode 0x9) carrying data.
func (ws *Websocket) SendPing(data []byte) error { return ws.sendFinal(0x9, data) }

// SendPong sends a pong control frame (opcode 0xA) carrying data.
func (ws *Websocket) SendPong(data []byte) error { return ws.sendFinal(0xA, data) }
// Close sends our close frame (at most once). If the peer has already sent
// its own close frame, Close also closes the underlying connection.
//
// A non-nil reason with code 0 is sent with status 1000 (normal closure).
// Code 0 with a nil reason sends a close frame with no payload. A control
// frame payload is limited to 125 bytes, so reason is truncated to 123 bytes,
// backing up to a UTF-8 rune boundary.
func (ws *Websocket) Close(code uint16, reason []byte) {
	if !ws.serverTerminated {
		// 125-byte control payload limit minus the 2-byte status code.
		const maxReason = 123
		if code == 0 && reason != nil {
			code = 1000
		}
		if len(reason) > maxReason {
			n := maxReason
			// Step back while reason[n] is a UTF-8 continuation byte, so the
			// cut never splits a multi-byte character.
			for n > 0 && reason[n]&0xC0 == 0x80 {
				n--
			}
			reason = reason[:n]
		}
		data := make([]byte, 0, len(reason)+2)
		if code != 0 {
			data = append(data, 0, 0)
			binary.BigEndian.PutUint16(data, code)
		}
		data = append(data, reason...)
		// Best effort: the closing handshake proceeds even if the write fails.
		_ = ws.SendFrame(true, 0x8, data)
		ws.serverTerminated = true
	}
	if ws.clientTerminated {
		ws.conn.Close()
	}
}
// RecvFrame reads one frame from the connection and returns its FIN flag and
// its unmasked payload.
//
// Data frames (continuation 0x0, text 0x1, binary 0x2) return their payload
// in message. Control frames are handled here and return a nil message:
//   - close (0x8): marks the client as terminated, answers through
//     Close(0, nil), then calls OnClose with the status code (0 if absent)
//     and reason.
//   - ping (0x9): answered with a pong echoing the ping payload.
//   - pong (0xA): passed to OnPong.
//
// RecvFrame returns an error for any of the following:
//   - reserved bits set;
//   - an unknown opcode;
//   - a fragmented control frame;
//   - a control frame of 126 bytes or more;
//   - a 64-bit length whose most significant bit is set.
func (ws *Websocket) RecvFrame() (final bool, message []byte, err error) {
	var hdr [8]byte
	if err = ws.read(hdr[:2]); err != nil {
		return
	}
	header, payload := hdr[0], hdr[1]
	final = header&0x80 != 0
	if header&0x70 != 0 {
		// RSV1-3 only carry meaning with a negotiated extension; none is supported.
		err = ErrReservedBits
		return
	}
	opcode := header & 0xf
	switch opcode {
	case 0x0, 0x1, 0x2, 0x8, 0x9, 0xA:
	default:
		// Validated up front so invalid opcodes are rejected on non-final frames too.
		err = ErrInvalidOpcode
		return
	}
	isControl := opcode&0x8 != 0
	masked := payload&0x80 != 0
	payloadLen := uint64(payload & 0x7f)
	if isControl && payloadLen >= 126 {
		err = ErrFrameOverload
		return
	}
	if isControl && !final {
		err = ErrFrameFragmented
		return
	}

	// Decode the payload length: 7-bit inline, or a 16/64-bit extension.
	frameLength := payloadLen
	switch payloadLen {
	case 126:
		if err = ws.read(hdr[:2]); err != nil {
			return
		}
		frameLength = uint64(binary.BigEndian.Uint16(hdr[:2]))
	case 127:
		if err = ws.read(hdr[:8]); err != nil {
			return
		}
		frameLength = binary.BigEndian.Uint64(hdr[:8])
		if frameLength>>63 != 0 {
			// RFC 6455 section 5.2: the most significant bit MUST be 0.
			err = fmt.Errorf("invalid frame length %d", frameLength)
			return
		}
	}

	var mask [4]byte
	if masked {
		if err = ws.read(mask[:]); err != nil {
			return
		}
	}
	// fmt.Println("final_frame:", final, "frame_opcode:", opcode, "mask_frame:", masked, "frame_length:", frameLength)
	message = make([]byte, frameLength)
	if frameLength > 0 {
		if err = ws.read(message); err != nil {
			return
		}
		if masked {
			websocketMask(mask[:], message)
		}
	}

	switch opcode {
	case 0x8: // close
		var code uint16
		var reason []byte
		if frameLength >= 2 {
			code = binary.BigEndian.Uint16(message[:2])
		}
		if frameLength > 2 {
			reason = message[2:]
		}
		message = nil
		ws.clientTerminated = true
		ws.Close(0, nil)
		ws.handler.OnClose(ws, code, reason)
	case 0x9: // ping: the pong must carry the ping's application data
		ping := message
		message = nil
		err = ws.SendPong(ping)
	case 0xA: // pong
		ws.handler.OnPong(ws, message)
		message = nil
	}
	// Continuation, text and binary payloads are returned unchanged.
	return
}

// Recv reads frames until one with FIN set arrives and concatenates their
// payloads into a single message. A non-empty message is also passed to
// OnMessage. Control frames contribute no payload, so a Recv that consumed
// only a control frame returns an empty message.
//
// The first frame error aborts the message. It is returned together with the
// data accumulated before the failing frame.
//
// NOTE(review): RFC 6455 section 5.4 allows a control frame between the
// fragments of a data message. Here such a frame ends the message early,
// because RecvFrame does not report the opcode. The remaining fragments are
// then delivered by the next Recv.
func (ws *Websocket) Recv() ([]byte, error) {
	data := make([]byte, 0, 8)
	for {
		final, message, err := ws.RecvFrame()
		// Check the error before FIN: errors on final frames must not be dropped.
		if err != nil {
			return data, err
		}
		data = append(data, message...)
		if final {
			break
		}
	}
	if len(data) > 0 {
		ws.handler.OnMessage(ws, data)
	}
	return data, nil
}
// Start runs the receive loop, dispatching events to the handler, until Recv
// returns an error. It then closes the connection and returns. After a close
// handshake, the next read on the closed connection fails, which ends the loop.
func (ws *Websocket) Start() {
	for {
		if _, err := ws.Recv(); err != nil {
			// Return here: Recv on a closed conn fails instantly, so looping
			// on would busy-spin forever.
			ws.conn.Close()
			return
		}
	}
}