You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
414 lines
7.8 KiB
414 lines
7.8 KiB
package app
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha1"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/go-mysql-org/go-mysql/replication"
|
|
)
|
|
|
|
// Wire-protocol constants.
const (
	// MinProtocolVersion is the lowest handshake protocol version (the first
	// byte of the server greeting) this client accepts.
	MinProtocolVersion byte = 0x0a

	// First payload byte identifying the kind of a server response packet.
	OK_HEADER          byte = 0x00
	ERR_HEADER         byte = 0xff
	EOF_HEADER         byte = 0xfe
	LocalInFile_HEADER byte = 0xfb

	// MaxPayloadLength is the largest payload one physical packet can carry
	// (the maximum of the 3-byte length field, 2^24-1); longer payloads are
	// split across several packets.
	MaxPayloadLength = 0xffffff
)
|
|
|
|
type Server struct {
|
|
Cfg *Config
|
|
Ctx context.Context
|
|
conn net.Conn
|
|
io *PacketIo
|
|
registerSucc bool
|
|
}
|
|
|
|
func (s *Server) Run() {
|
|
defer func() {
|
|
s.Quit()
|
|
}()
|
|
|
|
s.dump()
|
|
}
|
|
|
|
// dump runs one replication session. It performs the handshake, disables
// binlog checksums and registers as a replica. It then requests the binlog
// stream and prints every event to stdout.
//
// It returns when reading from the connection fails, when the server ends
// the stream (EOF packet), or when the server reports an error. A failed
// handshake panics.
func (s *Server) dump() {
	if err := s.handshake(); err != nil {
		panic(err)
	}
	s.invalidChecksum()
	fmt.Println("dump ...")
	s.register()
	s.writeDumpCommand()

	parser := replication.NewBinlogParser()
	for {
		data, err := s.io.readPacket()
		if err != nil {
			// The connection is closed or out of sync; retrying would only
			// spin on the same error forever.
			fmt.Println(err)
			return
		}
		if len(data) == 0 {
			continue
		}

		switch {
		case data[0] == OK_HEADER:
			// Each event packet starts with an OK marker byte; skip it.
			e, err := parser.Parse(data[1:])
			if err != nil {
				fmt.Println(err)
				continue
			}
			e.Dump(os.Stdout)
		case data[0] == EOF_HEADER && len(data) < 9:
			// An EOF packet (0xfe, shorter than 9 bytes) ends the stream.
			return
		default:
			// ERR packet (or anything unexpected): report it and stop.
			s.io.HandleError(data)
			return
		}
	}
}
|
|
|
|
func (s *Server) invalidChecksum() {
|
|
sql := `SET @master_binlog_checksum='NONE'`
|
|
if err := s.query(sql); err != nil {
|
|
fmt.Println(err)
|
|
}
|
|
//must read from tcp connection , either will be blocked
|
|
_, _ = s.io.readPacket()
|
|
}
|
|
|
|
func (s *Server) handshake() error {
|
|
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", s.Cfg.Host, s.Cfg.Port), 10*time.Second)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tc := conn.(*net.TCPConn)
|
|
tc.SetKeepAlive(true)
|
|
tc.SetNoDelay(true)
|
|
s.conn = tc
|
|
|
|
s.io = &PacketIo{}
|
|
s.io.r = bufio.NewReaderSize(s.conn, 16*1024)
|
|
s.io.w = tc
|
|
|
|
data, err := s.io.readPacket()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if data[0] == ERR_HEADER {
|
|
return errors.New("error packet")
|
|
}
|
|
|
|
if data[0] < MinProtocolVersion {
|
|
return fmt.Errorf("version is too lower, current:%d", data[0])
|
|
}
|
|
|
|
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1
|
|
connId := uint32(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
|
pos += 4
|
|
salt := data[pos : pos+8]
|
|
|
|
pos += 8 + 1
|
|
capability := uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
|
|
|
pos += 2
|
|
|
|
var status uint16
|
|
var pluginName string
|
|
if len(data) > pos {
|
|
//skip charset
|
|
pos++
|
|
status = binary.LittleEndian.Uint16(data[pos : pos+2])
|
|
pos += 2
|
|
capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | capability
|
|
pos += 2
|
|
|
|
pos += 10 + 1
|
|
salt = append(salt, data[pos:pos+12]...)
|
|
pos += 13
|
|
|
|
if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
|
|
pluginName = string(data[pos : pos+end])
|
|
} else {
|
|
pluginName = string(data[pos:])
|
|
}
|
|
}
|
|
|
|
fmt.Printf("conn_id:%v, status:%d, plugin:%v\n", connId, status, pluginName)
|
|
|
|
//write
|
|
capability = 500357
|
|
length := 4 + 4 + 1 + 23
|
|
length += len(s.Cfg.User) + 1
|
|
|
|
pass := []byte(s.Cfg.Pass)
|
|
auth := calPassword(salt[:20], pass)
|
|
length += 1 + len(auth)
|
|
data = make([]byte, length+4)
|
|
|
|
data[4] = byte(capability)
|
|
data[5] = byte(capability >> 8)
|
|
data[6] = byte(capability >> 16)
|
|
data[7] = byte(capability >> 24)
|
|
|
|
//utf8
|
|
data[12] = byte(33)
|
|
pos = 13 + 23
|
|
if len(s.Cfg.User) > 0 {
|
|
pos += copy(data[pos:], s.Cfg.User)
|
|
}
|
|
|
|
pos++
|
|
data[pos] = byte(len(auth))
|
|
pos += 1 + copy(data[pos+1:], auth)
|
|
|
|
err = s.io.writePacket(data)
|
|
if err != nil {
|
|
return fmt.Errorf("write auth packet error")
|
|
}
|
|
|
|
pk, err := s.io.readPacket()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if pk[0] == OK_HEADER {
|
|
fmt.Println("handshake ok ")
|
|
return nil
|
|
} else if pk[0] == ERR_HEADER {
|
|
s.io.HandleError(pk)
|
|
return errors.New("handshake error ")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeDumpCommand sends COM_BINLOG_DUMP (18) asking the server to stream
// binlog events from s.Cfg.LogFile at s.Cfg.Position, for replica id
// s.Cfg.ServerId. It then reads one reply packet. Write/read failures and
// empty replies are reported on stdout instead of panicking.
//
// NOTE(review): per the MySQL replication protocol this command has no OK
// reply. The packet consumed here appears to be the first event of the
// stream (events start with the 0x00 marker), so that event is never
// printed by dump. Confirm before relying on it.
func (s *Server) writeDumpCommand() {
	s.io.seq = 0

	// header(4) + command(1) + binlog position(4) + flags(2) + server id(4) + file name
	data := make([]byte, 4+1+4+2+4+len(s.Cfg.LogFile))
	pos := 4
	data[pos] = 18 // dump binlog
	pos++
	binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.Position))
	pos += 4
	binary.LittleEndian.PutUint16(data[pos:], 0) // dump command flags
	pos += 2
	binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.ServerId))
	pos += 4
	copy(data[pos:], s.Cfg.LogFile)

	if err := s.io.writePacket(data); err != nil {
		fmt.Println(err)
		return
	}

	res, err := s.io.readPacket()
	switch {
	case err != nil:
		fmt.Println(err)
	case len(res) == 0:
		fmt.Println("empty reply to binlog dump command")
	case res[0] == OK_HEADER:
		fmt.Println("send dump command return ok.")
	default:
		s.io.HandleError(res)
	}
}
|
|
|
|
// register sends COM_REGISTER_SLAVE (21), announcing this client as a
// replica with id s.Cfg.ServerId. It sets s.registerSucc when the server
// answers with OK. Otherwise it decodes and prints the server's reply.
// Write/read failures are reported on stdout.
//
// NOTE(review): the report fields carry the local hostname plus the login
// user, password and port from s.Cfg. The password therefore travels in
// clear text inside this packet. Confirm this is intended.
func (s *Server) register() {
	s.io.seq = 0
	hostname, _ := os.Hostname() // best effort: empty on failure

	// header(4) + command(1) + server id(4) + hostname(1+n) + user(1+n) +
	// password(1+n) + port(2) + replication rank(4) + master id(4)
	data := make([]byte, 4+1+4+1+len(hostname)+1+len(s.Cfg.User)+1+len(s.Cfg.Pass)+2+4+4)
	pos := 4
	data[pos] = 21 // register slave command
	pos++
	binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.ServerId))
	pos += 4

	data[pos] = uint8(len(hostname))
	pos++
	pos += copy(data[pos:], hostname)

	data[pos] = uint8(len(s.Cfg.User))
	pos++
	pos += copy(data[pos:], s.Cfg.User)

	data[pos] = uint8(len(s.Cfg.Pass))
	pos++
	pos += copy(data[pos:], s.Cfg.Pass)

	binary.LittleEndian.PutUint16(data[pos:], uint16(s.Cfg.Port))
	pos += 2
	binary.LittleEndian.PutUint32(data[pos:], 0) // replication rank
	pos += 4
	binary.LittleEndian.PutUint32(data[pos:], 0) // master id = 0

	if err := s.io.writePacket(data); err != nil {
		fmt.Println(err)
		return
	}

	res, err := s.io.readPacket()
	switch {
	case err != nil:
		fmt.Println(err)
	case len(res) == 0:
		fmt.Println("empty reply to register command")
	case res[0] == OK_HEADER:
		fmt.Println("register success.")
		s.registerSucc = true
	default:
		// Decode the server's reply, not the request we sent.
		s.io.HandleError(res)
	}
}
|
|
|
|
func (s *Server) writeCommand(command byte) {
|
|
s.io.seq = 0
|
|
_ = s.io.writePacket([]byte{
|
|
0x01, //1 byte long
|
|
0x00,
|
|
0x00,
|
|
0x00, //seq
|
|
command,
|
|
})
|
|
}
|
|
|
|
func (s *Server) query(q string) error {
|
|
s.io.seq = 0
|
|
length := len(q) + 1
|
|
data := make([]byte, length+4)
|
|
data[4] = 3
|
|
copy(data[5:], q)
|
|
return s.io.writePacket(data)
|
|
}
|
|
|
|
// Quit sends COM_QUIT (best effort) and closes the connection. A failed
// close is printed to stdout.
//
// It is a no-op when no connection was ever established (e.g. the dial in
// handshake failed). Without that guard, Run's deferred Quit would nil-deref
// and mask the original handshake panic.
func (s *Server) Quit() {
	if s.conn == nil {
		return
	}
	if s.io != nil {
		s.writeCommand(byte(1)) // COM_QUIT
	}
	if err := s.conn.Close(); err != nil {
		fmt.Printf("error in close :%v\n", err)
	}
}
|
|
|
|
type PacketIo struct {
|
|
r *bufio.Reader
|
|
w io.Writer
|
|
seq uint8
|
|
}
|
|
|
|
func (p *PacketIo) readPacket() ([]byte, error) {
|
|
//to read header
|
|
header := []byte{0, 0, 0, 0}
|
|
if _, err := io.ReadFull(p.r, header); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
|
|
if length == 0 {
|
|
p.seq++
|
|
return []byte{}, nil
|
|
}
|
|
|
|
if length == 1 {
|
|
return nil, fmt.Errorf("invalid payload")
|
|
}
|
|
|
|
seq := uint8(header[3])
|
|
if p.seq != seq {
|
|
return nil, fmt.Errorf("invalid seq %d", seq)
|
|
}
|
|
|
|
p.seq++
|
|
data := make([]byte, length)
|
|
if _, err := io.ReadFull(p.r, data); err != nil {
|
|
return nil, err
|
|
} else {
|
|
if length < MaxPayloadLength {
|
|
return data, nil
|
|
}
|
|
var buf []byte
|
|
buf, err = p.readPacket()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(buf) == 0 {
|
|
return data, nil
|
|
} else {
|
|
return append(data, buf...), nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *PacketIo) writePacket(data []byte) error {
|
|
length := len(data) - 4
|
|
if length >= MaxPayloadLength {
|
|
data[0] = 0xff
|
|
data[1] = 0xff
|
|
data[2] = 0xff
|
|
data[3] = p.seq
|
|
|
|
if n, err := p.w.Write(data[:4+MaxPayloadLength]); err != nil {
|
|
return fmt.Errorf("write find error")
|
|
} else if n != 4+MaxPayloadLength {
|
|
return fmt.Errorf("not equal max pay load length")
|
|
} else {
|
|
p.seq++
|
|
length -= MaxPayloadLength
|
|
data = data[MaxPayloadLength:]
|
|
}
|
|
}
|
|
|
|
data[0] = byte(length)
|
|
data[1] = byte(length >> 8)
|
|
data[2] = byte(length >> 16)
|
|
data[3] = p.seq
|
|
|
|
if n, err := p.w.Write(data); err != nil {
|
|
return errors.New("write find error")
|
|
} else if n != len(data) {
|
|
return errors.New("not equal length")
|
|
} else {
|
|
p.seq++
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// calPassword computes the native-password auth response for password given
// the server's scramble:
//
//	SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))
//
// An empty password yields an empty (nil) response. That is what clients
// send for accounts without a password; a 20-byte hash of "" is not
// accepted in its place. scramble is not modified.
func calPassword(scramble, password []byte) []byte {
	if len(password) == 0 {
		return nil
	}
	stage1 := sha1.Sum(password)
	stage2 := sha1.Sum(stage1[:])

	h := sha1.New()
	h.Write(scramble)
	h.Write(stage2[:])
	token := h.Sum(nil)

	for i := range token {
		token[i] ^= stage1[i]
	}
	return token
}
|
|
|
|
func (p *PacketIo) HandleError(data []byte) {
|
|
pos := 1
|
|
code := binary.LittleEndian.Uint16(data[pos:])
|
|
pos += 2
|
|
pos++
|
|
state := string(data[pos : pos+5])
|
|
pos += 5
|
|
msg := string(data[pos:])
|
|
fmt.Printf("code:%d, state:%s, msg:%s\n", code, state, msg)
|
|
}
|
|
|