9 changed files with 171 additions and 849 deletions
@ -1,12 +0,0 @@ |
|||
package app |
|||
|
|||
type Config struct { |
|||
Host string |
|||
Port int |
|||
User string |
|||
Pass string |
|||
ServerId int |
|||
|
|||
LogFile string |
|||
Position int |
|||
} |
|||
@ -1,414 +0,0 @@ |
|||
package app |
|||
|
|||
import ( |
|||
"bufio" |
|||
"bytes" |
|||
"context" |
|||
"crypto/sha1" |
|||
"encoding/binary" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"net" |
|||
"os" |
|||
"time" |
|||
|
|||
"github.com/go-mysql-org/go-mysql/replication" |
|||
) |
|||
|
|||
const ( |
|||
MinProtocolVersion byte = 10 |
|||
|
|||
OK_HEADER byte = 0x00 |
|||
ERR_HEADER byte = 0xff |
|||
EOF_HEADER byte = 0xfe |
|||
LocalInFile_HEADER byte = 0xfb |
|||
) |
|||
|
|||
const MaxPayloadLength = 1<<24 - 1 |
|||
|
|||
type Server struct { |
|||
Cfg *Config |
|||
Ctx context.Context |
|||
conn net.Conn |
|||
io *PacketIo |
|||
registerSucc bool |
|||
} |
|||
|
|||
func (s *Server) Run() { |
|||
defer func() { |
|||
s.Quit() |
|||
}() |
|||
|
|||
s.dump() |
|||
} |
|||
|
|||
func (s *Server) dump() { |
|||
err := s.handshake() |
|||
if err != nil { |
|||
panic(err) |
|||
} |
|||
s.invalidChecksum() |
|||
fmt.Println("dump ...") |
|||
s.register() |
|||
s.writeDumpCommand() |
|||
parser := replication.NewBinlogParser() |
|||
for { |
|||
//time.Sleep(2 * time.Second)
|
|||
//s.query("select 1")
|
|||
|
|||
data, err := s.io.readPacket() |
|||
if err != nil || len(data) == 0 { |
|||
continue |
|||
} |
|||
|
|||
//s.Quit()
|
|||
|
|||
if data[0] == OK_HEADER { |
|||
//skip ok
|
|||
data = data[1:] |
|||
if e, err := parser.Parse(data); err == nil { |
|||
e.Dump(os.Stdout) |
|||
} else { |
|||
fmt.Println(err) |
|||
} |
|||
} else { |
|||
s.io.HandleError(data) |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (s *Server) invalidChecksum() { |
|||
sql := `SET @master_binlog_checksum='NONE'` |
|||
if err := s.query(sql); err != nil { |
|||
fmt.Println(err) |
|||
} |
|||
//must read from tcp connection , either will be blocked
|
|||
_, _ = s.io.readPacket() |
|||
} |
|||
|
|||
func (s *Server) handshake() error { |
|||
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", s.Cfg.Host, s.Cfg.Port), 10*time.Second) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
tc := conn.(*net.TCPConn) |
|||
tc.SetKeepAlive(true) |
|||
tc.SetNoDelay(true) |
|||
s.conn = tc |
|||
|
|||
s.io = &PacketIo{} |
|||
s.io.r = bufio.NewReaderSize(s.conn, 16*1024) |
|||
s.io.w = tc |
|||
|
|||
data, err := s.io.readPacket() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
if data[0] == ERR_HEADER { |
|||
return errors.New("error packet") |
|||
} |
|||
|
|||
if data[0] < MinProtocolVersion { |
|||
return fmt.Errorf("version is too lower, current:%d", data[0]) |
|||
} |
|||
|
|||
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 |
|||
connId := uint32(binary.LittleEndian.Uint32(data[pos : pos+4])) |
|||
pos += 4 |
|||
salt := data[pos : pos+8] |
|||
|
|||
pos += 8 + 1 |
|||
capability := uint32(binary.LittleEndian.Uint16(data[pos : pos+2])) |
|||
|
|||
pos += 2 |
|||
|
|||
var status uint16 |
|||
var pluginName string |
|||
if len(data) > pos { |
|||
//skip charset
|
|||
pos++ |
|||
status = binary.LittleEndian.Uint16(data[pos : pos+2]) |
|||
pos += 2 |
|||
capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | capability |
|||
pos += 2 |
|||
|
|||
pos += 10 + 1 |
|||
salt = append(salt, data[pos:pos+12]...) |
|||
pos += 13 |
|||
|
|||
if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { |
|||
pluginName = string(data[pos : pos+end]) |
|||
} else { |
|||
pluginName = string(data[pos:]) |
|||
} |
|||
} |
|||
|
|||
fmt.Printf("conn_id:%v, status:%d, plugin:%v\n", connId, status, pluginName) |
|||
|
|||
//write
|
|||
capability = 500357 |
|||
length := 4 + 4 + 1 + 23 |
|||
length += len(s.Cfg.User) + 1 |
|||
|
|||
pass := []byte(s.Cfg.Pass) |
|||
auth := calPassword(salt[:20], pass) |
|||
length += 1 + len(auth) |
|||
data = make([]byte, length+4) |
|||
|
|||
data[4] = byte(capability) |
|||
data[5] = byte(capability >> 8) |
|||
data[6] = byte(capability >> 16) |
|||
data[7] = byte(capability >> 24) |
|||
|
|||
//utf8
|
|||
data[12] = byte(33) |
|||
pos = 13 + 23 |
|||
if len(s.Cfg.User) > 0 { |
|||
pos += copy(data[pos:], s.Cfg.User) |
|||
} |
|||
|
|||
pos++ |
|||
data[pos] = byte(len(auth)) |
|||
pos += 1 + copy(data[pos+1:], auth) |
|||
|
|||
err = s.io.writePacket(data) |
|||
if err != nil { |
|||
return fmt.Errorf("write auth packet error") |
|||
} |
|||
|
|||
pk, err := s.io.readPacket() |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
if pk[0] == OK_HEADER { |
|||
fmt.Println("handshake ok ") |
|||
return nil |
|||
} else if pk[0] == ERR_HEADER { |
|||
s.io.HandleError(pk) |
|||
return errors.New("handshake error ") |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func (s *Server) writeDumpCommand() { |
|||
s.io.seq = 0 |
|||
data := make([]byte, 4+1+4+2+4+len(s.Cfg.LogFile)) |
|||
pos := 4 |
|||
data[pos] = 18 //dump binlog
|
|||
pos++ |
|||
binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.Position)) |
|||
pos += 4 |
|||
|
|||
//dump command flag
|
|||
binary.LittleEndian.PutUint16(data[pos:], 0) |
|||
pos += 2 |
|||
|
|||
binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.ServerId)) |
|||
pos += 4 |
|||
|
|||
copy(data[pos:], s.Cfg.LogFile) |
|||
|
|||
s.io.writePacket(data) |
|||
//ok
|
|||
res, _ := s.io.readPacket() |
|||
if res[0] == OK_HEADER { |
|||
fmt.Println("send dump command return ok.") |
|||
} else { |
|||
s.io.HandleError(res) |
|||
} |
|||
} |
|||
|
|||
func (s *Server) register() { |
|||
s.io.seq = 0 |
|||
hostname, _ := os.Hostname() |
|||
data := make([]byte, 4+1+4+1+len(hostname)+1+len(s.Cfg.User)+1+len(s.Cfg.Pass)+2+4+4) |
|||
pos := 4 |
|||
data[pos] = 21 //register slave command
|
|||
pos++ |
|||
binary.LittleEndian.PutUint32(data[pos:], uint32(s.Cfg.ServerId)) |
|||
pos += 4 |
|||
|
|||
data[pos] = uint8(len(hostname)) |
|||
pos++ |
|||
n := copy(data[pos:], hostname) |
|||
pos += n |
|||
|
|||
data[pos] = uint8(len(s.Cfg.User)) |
|||
pos++ |
|||
n = copy(data[pos:], s.Cfg.User) |
|||
pos += n |
|||
|
|||
data[pos] = uint8(len(s.Cfg.Pass)) |
|||
pos++ |
|||
n = copy(data[pos:], s.Cfg.Pass) |
|||
pos += n |
|||
|
|||
binary.LittleEndian.PutUint16(data[pos:], uint16(s.Cfg.Port)) |
|||
pos += 2 |
|||
|
|||
binary.LittleEndian.PutUint32(data[pos:], 0) |
|||
pos += 4 |
|||
|
|||
//master id = 0
|
|||
binary.LittleEndian.PutUint32(data[pos:], 0) |
|||
|
|||
s.io.writePacket(data) |
|||
|
|||
//ok
|
|||
res, _ := s.io.readPacket() |
|||
if res[0] == OK_HEADER { |
|||
fmt.Println("register success.") |
|||
s.registerSucc = true |
|||
} else { |
|||
s.io.HandleError(data) |
|||
} |
|||
} |
|||
|
|||
func (s *Server) writeCommand(command byte) { |
|||
s.io.seq = 0 |
|||
_ = s.io.writePacket([]byte{ |
|||
0x01, //1 byte long
|
|||
0x00, |
|||
0x00, |
|||
0x00, //seq
|
|||
command, |
|||
}) |
|||
} |
|||
|
|||
func (s *Server) query(q string) error { |
|||
s.io.seq = 0 |
|||
length := len(q) + 1 |
|||
data := make([]byte, length+4) |
|||
data[4] = 3 |
|||
copy(data[5:], q) |
|||
return s.io.writePacket(data) |
|||
} |
|||
|
|||
func (s *Server) Quit() { |
|||
//quit
|
|||
s.writeCommand(byte(1)) |
|||
//maybe only close
|
|||
if err := s.conn.Close(); nil != err { |
|||
fmt.Printf("error in close :%v\n", err) |
|||
} |
|||
} |
|||
|
|||
type PacketIo struct { |
|||
r *bufio.Reader |
|||
w io.Writer |
|||
seq uint8 |
|||
} |
|||
|
|||
func (p *PacketIo) readPacket() ([]byte, error) { |
|||
//to read header
|
|||
header := []byte{0, 0, 0, 0} |
|||
if _, err := io.ReadFull(p.r, header); err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) |
|||
if length == 0 { |
|||
p.seq++ |
|||
return []byte{}, nil |
|||
} |
|||
|
|||
if length == 1 { |
|||
return nil, fmt.Errorf("invalid payload") |
|||
} |
|||
|
|||
seq := uint8(header[3]) |
|||
if p.seq != seq { |
|||
return nil, fmt.Errorf("invalid seq %d", seq) |
|||
} |
|||
|
|||
p.seq++ |
|||
data := make([]byte, length) |
|||
if _, err := io.ReadFull(p.r, data); err != nil { |
|||
return nil, err |
|||
} else { |
|||
if length < MaxPayloadLength { |
|||
return data, nil |
|||
} |
|||
var buf []byte |
|||
buf, err = p.readPacket() |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
if len(buf) == 0 { |
|||
return data, nil |
|||
} else { |
|||
return append(data, buf...), nil |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (p *PacketIo) writePacket(data []byte) error { |
|||
length := len(data) - 4 |
|||
if length >= MaxPayloadLength { |
|||
data[0] = 0xff |
|||
data[1] = 0xff |
|||
data[2] = 0xff |
|||
data[3] = p.seq |
|||
|
|||
if n, err := p.w.Write(data[:4+MaxPayloadLength]); err != nil { |
|||
return fmt.Errorf("write find error") |
|||
} else if n != 4+MaxPayloadLength { |
|||
return fmt.Errorf("not equal max pay load length") |
|||
} else { |
|||
p.seq++ |
|||
length -= MaxPayloadLength |
|||
data = data[MaxPayloadLength:] |
|||
} |
|||
} |
|||
|
|||
data[0] = byte(length) |
|||
data[1] = byte(length >> 8) |
|||
data[2] = byte(length >> 16) |
|||
data[3] = p.seq |
|||
|
|||
if n, err := p.w.Write(data); err != nil { |
|||
return errors.New("write find error") |
|||
} else if n != len(data) { |
|||
return errors.New("not equal length") |
|||
} else { |
|||
p.seq++ |
|||
return nil |
|||
} |
|||
} |
|||
|
|||
func calPassword(scramble, password []byte) []byte { |
|||
crypt := sha1.New() |
|||
crypt.Write(password) |
|||
stage1 := crypt.Sum(nil) |
|||
|
|||
crypt.Reset() |
|||
crypt.Write(stage1) |
|||
hash := crypt.Sum(nil) |
|||
|
|||
crypt.Reset() |
|||
crypt.Write(scramble) |
|||
crypt.Write(hash) |
|||
scramble = crypt.Sum(nil) |
|||
|
|||
for i := range scramble { |
|||
scramble[i] ^= stage1[i] |
|||
} |
|||
|
|||
return scramble |
|||
} |
|||
|
|||
func (p *PacketIo) HandleError(data []byte) { |
|||
pos := 1 |
|||
code := binary.LittleEndian.Uint16(data[pos:]) |
|||
pos += 2 |
|||
pos++ |
|||
state := string(data[pos : pos+5]) |
|||
pos += 5 |
|||
msg := string(data[pos:]) |
|||
fmt.Printf("code:%d, state:%s, msg:%s\n", code, state, msg) |
|||
} |
|||
@ -1,4 +1,4 @@ |
|||
set GOOS=linux |
|||
set GOARCH=amd64 |
|||
set CGO_ENABLED=0 |
|||
go build -o mblog mbmain.go |
|||
go build -o mblog main.go |
|||
@ -0,0 +1,162 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"flag" |
|||
"fmt" |
|||
"log" |
|||
"strings" |
|||
|
|||
"github.com/go-mysql-org/go-mysql/canal" |
|||
"github.com/go-mysql-org/go-mysql/mysql" |
|||
"github.com/go-mysql-org/go-mysql/replication" |
|||
) |
|||
|
|||
// 自定义事件处理结构体
|
|||
type MyEventHandler struct { |
|||
canal.DummyEventHandler |
|||
} |
|||
|
|||
// 处理行事件
|
|||
func (h *MyEventHandler) OnRow(e *canal.RowsEvent) error { |
|||
table := e.Table |
|||
sql := "" |
|||
|
|||
switch e.Action { |
|||
case canal.InsertAction: |
|||
// 处理插入事件
|
|||
columns := make([]string, len(table.Columns)) |
|||
for i, col := range table.Columns { |
|||
columns[i] = col.Name |
|||
} |
|||
|
|||
values := make([]string, len(e.Rows[0])) |
|||
for i, val := range e.Rows[0] { |
|||
values[i] = formatValue(val) |
|||
} |
|||
|
|||
sql = fmt.Sprintf("INSERT INTO `%s`.`%s` (%s) VALUES (%s);", |
|||
table.Schema, table.Name, |
|||
strings.Join(columns, ", "), |
|||
strings.Join(values, ", ")) |
|||
|
|||
case canal.UpdateAction: |
|||
// 处理更新事件
|
|||
oldRow := e.Rows[0] |
|||
newRow := e.Rows[1] |
|||
|
|||
sets := make([]string, 0) |
|||
wheres := make([]string, 0) |
|||
|
|||
for i, col := range table.Columns { |
|||
if oldRow[i] != newRow[i] { |
|||
sets = append(sets, fmt.Sprintf("`%s` = %s", col.Name, formatValue(newRow[i]))) |
|||
} |
|||
wheres = append(wheres, fmt.Sprintf("`%s` = %s", col.Name, formatValue(oldRow[i]))) |
|||
} |
|||
|
|||
sql = fmt.Sprintf("UPDATE `%s`.`%s` SET %s WHERE %s;", |
|||
table.Schema, table.Name, |
|||
strings.Join(sets, ", "), |
|||
strings.Join(wheres, " AND ")) |
|||
|
|||
case canal.DeleteAction: |
|||
// 处理删除事件
|
|||
wheres := make([]string, len(table.Columns)) |
|||
for i, col := range table.Columns { |
|||
wheres[i] = fmt.Sprintf("`%s` = %s", col.Name, formatValue(e.Rows[0][i])) |
|||
} |
|||
|
|||
sql = fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE %s;", |
|||
table.Schema, table.Name, |
|||
strings.Join(wheres, " AND ")) |
|||
} |
|||
|
|||
if sql != "" { |
|||
fmt.Println(sql) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// 处理DDL事件
|
|||
func (h *MyEventHandler) OnDDL(nextPos mysql.Position, queryEvent *replication.QueryEvent) error { |
|||
sql := string(queryEvent.Query) |
|||
if sql != "" { |
|||
fmt.Println(sql + ";") |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
// 格式化值为SQL表示形式
|
|||
func formatValue(value interface{}) string { |
|||
if value == nil { |
|||
return "NULL" |
|||
} |
|||
|
|||
switch v := value.(type) { |
|||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64: |
|||
return fmt.Sprintf("%v", v) |
|||
case []byte: |
|||
// 处理二进制数据
|
|||
return fmt.Sprintf("X'%x'", v) |
|||
case string: |
|||
// 转义单引号
|
|||
return fmt.Sprintf("'%s'", strings.ReplaceAll(v, "'", "''")) |
|||
default: |
|||
return fmt.Sprintf("'%v'", v) |
|||
} |
|||
} |
|||
|
|||
func main() { |
|||
// 解析命令行参数
|
|||
host := flag.String("host", "localhost", "MySQL主机地址") |
|||
port := flag.Uint("port", 3306, "MySQL端口") |
|||
user := flag.String("user", "root", "MySQL用户名") |
|||
password := flag.String("password", "", "MySQL密码") |
|||
serverID := flag.Uint("server-id", 1001, "客户端服务器ID") |
|||
flavor := flag.String("flavor", "mysql", "数据库类型 (mysql或mariadb)") |
|||
startFile := flag.String("start-file", "", "开始读取的binlog文件名") |
|||
startPos := flag.Uint("start-pos", 4, "开始读取的位置") |
|||
|
|||
flag.Parse() |
|||
|
|||
// 创建canal配置
|
|||
cfg := canal.NewDefaultConfig() |
|||
cfg.Addr = fmt.Sprintf("%s:%d", *host, *port) |
|||
cfg.User = *user |
|||
cfg.Password = *password |
|||
cfg.ServerID = uint32(*serverID) |
|||
cfg.Flavor = *flavor |
|||
|
|||
// 设置需要监听的数据库和表,默认监听所有
|
|||
// cfg.Dump.TableDB = "test_db"
|
|||
// cfg.Dump.Tables = []string{"test_table"}
|
|||
|
|||
// 创建canal实例
|
|||
c, err := canal.NewCanal(cfg) |
|||
if err != nil { |
|||
log.Fatalf("创建canal实例失败: %v", err) |
|||
} |
|||
|
|||
// 设置事件处理器
|
|||
c.SetEventHandler(&MyEventHandler{}) |
|||
|
|||
// 设置起始位置
|
|||
var pos mysql.Position |
|||
if *startFile != "" { |
|||
pos = mysql.Position{Name: *startFile, Pos: uint32(*startPos)} |
|||
} else { |
|||
// 如果未指定,从当前位置开始
|
|||
pos, err = c.GetMasterPos() |
|||
if err != nil { |
|||
log.Fatalf("获取主库位置失败: %v", err) |
|||
} |
|||
fmt.Printf("从binlog位置 %s:%d 开始读取\n", pos.Name, pos.Pos) |
|||
} |
|||
|
|||
// 开始同步
|
|||
err = c.RunFrom(pos) |
|||
if err != nil { |
|||
log.Fatalf("同步失败: %v", err) |
|||
} |
|||
} |
|||
Binary file not shown.
@ -1,344 +0,0 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"context" |
|||
"encoding/json" |
|||
"flag" |
|||
"fmt" |
|||
"log" |
|||
"os" |
|||
"os/signal" |
|||
"syscall" |
|||
"time" |
|||
|
|||
"github.com/go-mysql-org/go-mysql/mysql" |
|||
"github.com/go-mysql-org/go-mysql/replication" |
|||
) |
|||
|
|||
// SQLWriter 用于异步输出SQL语句
|
|||
type SQLWriter struct { |
|||
sqlCh chan string |
|||
} |
|||
|
|||
func NewSQLWriter() *SQLWriter { |
|||
return &SQLWriter{ |
|||
sqlCh: make(chan string, 1000), // 缓冲通道
|
|||
} |
|||
} |
|||
|
|||
// Start 启动异步SQL写入器
|
|||
func (w *SQLWriter) Start() { |
|||
go func() { |
|||
for sql := range w.sqlCh { |
|||
// 这里可以替换为实际写入文件、发送到消息队列等操作
|
|||
fmt.Printf("[%s] %s\n", time.Now().Format("2006-01-02 15:04:05"), sql) |
|||
} |
|||
}() |
|||
} |
|||
|
|||
// Write 异步写入SQL
|
|||
func (w *SQLWriter) Write(sql string) { |
|||
select { |
|||
case w.sqlCh <- sql: |
|||
// 成功写入通道
|
|||
default: |
|||
log.Println("警告: SQL通道拥堵,丢弃SQL语句:", sql) |
|||
} |
|||
} |
|||
|
|||
// Stop 停止写入器
|
|||
func (w *SQLWriter) Stop() { |
|||
close(w.sqlCh) |
|||
} |
|||
|
|||
type PositionManager struct { |
|||
filePath string |
|||
} |
|||
|
|||
func (pm *PositionManager) Save(pos mysql.Position) error { |
|||
data, _ := json.Marshal(pos) |
|||
return os.WriteFile(pm.filePath, data, 0644) |
|||
} |
|||
|
|||
func (pm *PositionManager) Load() (mysql.Position, error) { |
|||
data, err := os.ReadFile(pm.filePath) |
|||
if err != nil { |
|||
return mysql.Position{}, err |
|||
} |
|||
|
|||
var pos mysql.Position |
|||
err = json.Unmarshal(data, &pos) |
|||
return pos, err |
|||
} |
|||
|
|||
// 程序启动参数
|
|||
var user = flag.String("user", "root", "MySQL user, must have replication privilege") |
|||
var password = flag.String("password", "****", "MySQL password") |
|||
|
|||
// 从 arg参数中获取配置信息
|
|||
func main() { |
|||
flag.Parse() |
|||
|
|||
// 创建SQL写入器
|
|||
sqlWriter := NewSQLWriter() |
|||
sqlWriter.Start() |
|||
defer sqlWriter.Stop() |
|||
|
|||
// 配置Binlog同步器
|
|||
cfg := replication.BinlogSyncerConfig{ |
|||
ServerID: 100, // 唯一的ServerID
|
|||
Flavor: "mysql", |
|||
Host: "localhost", |
|||
Port: 3306, |
|||
User: *user, |
|||
Password: *password, |
|||
Charset: "utf8mb4", |
|||
} |
|||
|
|||
syncer := replication.NewBinlogSyncer(cfg) |
|||
defer syncer.Close() |
|||
|
|||
// 获取当前Binlog位置(可选)
|
|||
// 也可以从指定的位置开始,如 mysql.Position{Name: "mysql-bin.000001", Pos: 4}
|
|||
position := mysql.Position{Name: "", Pos: 4} |
|||
|
|||
streamer, err := syncer.StartSync(position) |
|||
if err != nil { |
|||
log.Fatalf("Failed to start sync: %v", err) |
|||
} |
|||
|
|||
log.Println("开始监听MySQL Binlog...") |
|||
|
|||
// 处理优雅退出
|
|||
signalCh := make(chan os.Signal, 1) |
|||
signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM) |
|||
|
|||
ctx, cancel := context.WithCancel(context.Background()) |
|||
defer cancel() |
|||
|
|||
// 启动事件处理循环
|
|||
go eventLoop(ctx, streamer, sqlWriter) |
|||
|
|||
<-signalCh |
|||
log.Println("收到退出信号,停止监听...") |
|||
} |
|||
|
|||
// 事件处理循环
|
|||
func eventLoop(ctx context.Context, streamer *replication.BinlogStreamer, writer *SQLWriter) { |
|||
for { |
|||
select { |
|||
case <-ctx.Done(): |
|||
return |
|||
default: |
|||
ev, err := streamer.GetEvent(ctx) |
|||
if err != nil { |
|||
if err == context.Canceled { |
|||
return |
|||
} |
|||
log.Printf("获取Binlog事件错误: %v", err) |
|||
continue |
|||
} |
|||
|
|||
// 解析Binlog事件
|
|||
if err := parseBinlogEvent(ev, writer); err != nil { |
|||
log.Printf("解析Binlog事件错误: %v", err) |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
// 解析Binlog事件并生成SQL语句
|
|||
func parseBinlogEvent(ev *replication.BinlogEvent, writer *SQLWriter) error { |
|||
event := ev.Header.EventType |
|||
|
|||
switch event { |
|||
case replication.WRITE_ROWS_EVENTv1, replication.WRITE_ROWS_EVENTv2: |
|||
return handleWriteRows(ev, writer) |
|||
case replication.DELETE_ROWS_EVENTv1, replication.DELETE_ROWS_EVENTv2: |
|||
return handleDeleteRows(ev, writer) |
|||
case replication.UPDATE_ROWS_EVENTv1, replication.UPDATE_ROWS_EVENTv2: |
|||
return handleUpdateRows(ev, writer) |
|||
case replication.QUERY_EVENT: |
|||
return handleQueryEvent(ev, writer) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// 处理INSERT事件
|
|||
func handleWriteRows(ev *replication.BinlogEvent, writer *SQLWriter) error { |
|||
rowsEvent, ok := ev.Event.(*replication.RowsEvent) |
|||
if !ok { |
|||
return fmt.Errorf("类型断言失败: 期望*replication.RowsEvent") |
|||
} |
|||
|
|||
tableName := string(rowsEvent.Table.Table) |
|||
schemaName := string(rowsEvent.Table.Schema) |
|||
|
|||
for _, row := range rowsEvent.Rows { |
|||
columns := make([]string, len(row)) |
|||
values := make([]interface{}, len(row)) |
|||
|
|||
for i, value := range row { |
|||
columns[i] = fmt.Sprintf("column%d", i) |
|||
values[i] = value |
|||
} |
|||
|
|||
sql := generateInsertSQL(schemaName, tableName, columns, values) |
|||
writer.Write(sql) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// 处理DELETE事件
|
|||
func handleDeleteRows(ev *replication.BinlogEvent, writer *SQLWriter) error { |
|||
rowsEvent, ok := ev.Event.(*replication.RowsEvent) |
|||
if !ok { |
|||
return fmt.Errorf("类型断言失败: 期望*replication.RowsEvent") |
|||
} |
|||
|
|||
tableName := string(rowsEvent.Table.Table) |
|||
schemaName := string(rowsEvent.Table.Schema) |
|||
|
|||
for _, row := range rowsEvent.Rows { |
|||
whereClause := generateWhereClause(row) |
|||
sql := fmt.Sprintf("DELETE FROM `%s`.`%s` WHERE %s", schemaName, tableName, whereClause) |
|||
writer.Write(sql) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// 处理UPDATE事件
|
|||
func handleUpdateRows(ev *replication.BinlogEvent, writer *SQLWriter) error { |
|||
rowsEvent, ok := ev.Event.(*replication.RowsEvent) |
|||
if !ok { |
|||
return fmt.Errorf("类型断言失败: 期望*replication.RowsEvent") |
|||
} |
|||
|
|||
tableName := string(rowsEvent.Table.Table) |
|||
schemaName := string(rowsEvent.Table.Schema) |
|||
|
|||
// Rows是成对出现的: [旧值, 新值]
|
|||
for i := 0; i < len(rowsEvent.Rows); i += 2 { |
|||
if i+1 >= len(rowsEvent.Rows) { |
|||
break |
|||
} |
|||
|
|||
oldRow := rowsEvent.Rows[i] |
|||
newRow := rowsEvent.Rows[i+1] |
|||
|
|||
setClause := generateSetClause(oldRow, newRow) |
|||
whereClause := generateWhereClause(oldRow) |
|||
|
|||
sql := fmt.Sprintf("UPDATE `%s`.`%s` SET %s WHERE %s", |
|||
schemaName, tableName, setClause, whereClause) |
|||
writer.Write(sql) |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// 处理QUERY事件(DDL语句)
|
|||
func handleQueryEvent(ev *replication.BinlogEvent, writer *SQLWriter) error { |
|||
queryEvent, ok := ev.Event.(*replication.QueryEvent) |
|||
if !ok { |
|||
return fmt.Errorf("类型断言失败: 期望*replication.QueryEvent") |
|||
} |
|||
|
|||
sql := string(queryEvent.Query) |
|||
writer.Write("-- DDL操作: " + sql) |
|||
|
|||
return nil |
|||
} |
|||
|
|||
// 生成INSERT SQL语句
|
|||
func generateInsertSQL(schema, table string, columns []string, values []interface{}) string { |
|||
valueStrs := make([]string, len(values)) |
|||
for i, v := range values { |
|||
valueStrs[i] = formatValue(v) |
|||
} |
|||
|
|||
return fmt.Sprintf("INSERT INTO `%s`.`%s` VALUES (%s);", |
|||
schema, table, joinValues(valueStrs)) |
|||
} |
|||
|
|||
// 生成WHERE子句
|
|||
func generateWhereClause(row []interface{}) string { |
|||
parts := make([]string, len(row)) |
|||
for i, value := range row { |
|||
parts[i] = fmt.Sprintf("column%d = %s", i, formatValue(value)) |
|||
} |
|||
return joinValues(parts) |
|||
} |
|||
|
|||
// 生成SET子句
|
|||
func generateSetClause(oldRow, newRow []interface{}) string { |
|||
parts := make([]string, len(newRow)) |
|||
for i, newValue := range newRow { |
|||
oldValue := oldRow[i] |
|||
// 只更新有变化的字段
|
|||
if fmt.Sprintf("%v", oldValue) != fmt.Sprintf("%v", newValue) { |
|||
parts[i] = fmt.Sprintf("column%d = %s", i, formatValue(newValue)) |
|||
} |
|||
} |
|||
|
|||
// 过滤空值
|
|||
var nonEmptyParts []string |
|||
for _, part := range parts { |
|||
if part != "" { |
|||
nonEmptyParts = append(nonEmptyParts, part) |
|||
} |
|||
} |
|||
|
|||
return joinValues(nonEmptyParts) |
|||
} |
|||
|
|||
// 格式化值
|
|||
func formatValue(value interface{}) string { |
|||
if value == nil { |
|||
return "NULL" |
|||
} |
|||
|
|||
switch v := value.(type) { |
|||
case string: |
|||
return fmt.Sprintf("'%s'", escapeString(v)) |
|||
case []byte: |
|||
return fmt.Sprintf("'%s'", escapeString(string(v))) |
|||
default: |
|||
return fmt.Sprintf("%v", v) |
|||
} |
|||
} |
|||
|
|||
// 转义字符串
|
|||
func escapeString(s string) string { |
|||
// 简单的转义,实际应用中可能需要更完整的实现
|
|||
return s |
|||
} |
|||
|
|||
// 连接值
|
|||
func joinValues(values []string) string { |
|||
if len(values) == 0 { |
|||
return "" |
|||
} |
|||
|
|||
result := values[0] |
|||
for i := 1; i < len(values); i++ { |
|||
result += ", " + values[i] |
|||
} |
|||
return result |
|||
} |
|||
|
|||
// withRetry 重试操作直到成功或超过最大重试次数
|
|||
func withRetry(operation func() error, maxRetries int) error { |
|||
var err error |
|||
for i := 0; i < maxRetries; i++ { |
|||
if err = operation(); err == nil { |
|||
return nil |
|||
} |
|||
log.Printf("操作失败,尝试重连 (%d/%d): %v", i+1, maxRetries, err) |
|||
time.Sleep(time.Duration(i+1) * time.Second) |
|||
} |
|||
return fmt.Errorf("超过最大重试次数: %v", err) |
|||
} |
|||
@ -1,78 +0,0 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"context" |
|||
"flag" |
|||
"fmt" |
|||
"os" |
|||
|
|||
"github.com/go-mysql-org/go-mysql/mysql" |
|||
"github.com/go-mysql-org/go-mysql/replication" |
|||
"github.com/pingcap/errors" |
|||
) |
|||
|
|||
var host = flag.String("host", "127.0.0.1", "MySQL host") |
|||
var port = flag.Int("port", 3306, "MySQL port") |
|||
var user = flag.String("user", "root", "MySQL user, must have replication privilege") |
|||
var password = flag.String("password", "****", "MySQL password") |
|||
|
|||
var flavor = flag.String("flavor", "mysql", "Flavor: mysql or mariadb") |
|||
|
|||
var file = flag.String("file", "mysql-bin.000032", "Binlog filename") |
|||
var pos = flag.Int("pos", 3070, "Binlog position") |
|||
|
|||
var semiSync = flag.Bool("semisync", false, "Support semi sync") |
|||
var backupPath = flag.String("backup_path", "", "backup path to store binlog files") |
|||
|
|||
var rawMode = flag.Bool("raw", false, "Use raw mode") |
|||
|
|||
func main() { |
|||
flag.Parse() |
|||
|
|||
cfg := replication.BinlogSyncerConfig{ |
|||
ServerID: 101, |
|||
Flavor: *flavor, |
|||
|
|||
Host: *host, |
|||
Port: uint16(*port), |
|||
User: *user, |
|||
Password: *password, |
|||
RawModeEnabled: *rawMode, |
|||
SemiSyncEnabled: *semiSync, |
|||
UseDecimal: true, |
|||
} |
|||
|
|||
b := replication.NewBinlogSyncer(cfg) |
|||
|
|||
pos := mysql.Position{Name: *file, Pos: uint32(*pos)} |
|||
if len(*backupPath) > 0 { |
|||
// Backup will always use RawMode. |
|||
err := b.StartBackup(*backupPath, pos, 0) |
|||
if err != nil { |
|||
fmt.Printf("Start backup error: %v\n", errors.ErrorStack(err)) |
|||
return |
|||
} |
|||
} else { |
|||
s, err := b.StartSync(pos) |
|||
if err != nil { |
|||
fmt.Printf("Start sync error: %v\n", errors.ErrorStack(err)) |
|||
return |
|||
} |
|||
|
|||
for { |
|||
e, err := s.GetEvent(context.Background()) |
|||
if err != nil { |
|||
// Try to output all left events |
|||
events := s.DumpEvents() |
|||
for _, e := range events { |
|||
e.Dump(os.Stdout) |
|||
} |
|||
fmt.Printf("Get event error: %v\n", errors.ErrorStack(err)) |
|||
return |
|||
} |
|||
|
|||
e.Dump(os.Stdout) |
|||
} |
|||
} |
|||
|
|||
} |
|||
Loading…
Reference in new issue