自动更新管控端
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

461 lines
10 KiB

package main
import (
"database/sql"
"encoding/csv"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
_ "github.com/go-sql-driver/mysql"
)
// BackupConfig 备份配置
type BackupConfig struct {
Host string
Port int
User string
Password string
Database string
BackupDir string
KeepDays int
Async bool
}
// BackupResult 备份结果
type BackupResult struct {
Success bool
Message string
StartTime time.Time
EndTime time.Time
Filename string
}
// BackupManager 备份管理器
type BackupManager struct {
config BackupConfig
db *sql.DB
mutex sync.Mutex
running bool
}
// NewBackupManager 创建新的备份管理器
func NewBackupManager(config BackupConfig) (*BackupManager, error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
config.User, config.Password, config.Host, config.Port, config.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return nil, fmt.Errorf("无法连接数据库: %v", err)
}
// 测试连接
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("数据库连接失败: %v", err)
}
return &BackupManager{
config: config,
db: db,
}, nil
}
// 获取数据库中所有表
func (m *BackupManager) getTables() ([]string, error) {
query := "SHOW TABLES"
rows, err := m.db.Query(query)
if err != nil {
return nil, fmt.Errorf("查询表失败: %v", err)
}
defer rows.Close()
var tables []string
for rows.Next() {
var table string
if err := rows.Scan(&table); err != nil {
return nil, fmt.Errorf("扫描表名失败: %v", err)
}
tables = append(tables, table)
}
return tables, nil
}
// 获取表结构
func (m *BackupManager) getTableSchema(table string) (string, error) {
query := fmt.Sprintf("SHOW CREATE TABLE `%s`", table)
rows, err := m.db.Query(query)
if err != nil {
return "", fmt.Errorf("查询表结构失败: %v", err)
}
defer rows.Close()
var (
tbl string
sql string
)
if rows.Next() {
if err := rows.Scan(&tbl, &sql); err != nil {
return "", fmt.Errorf("扫描表结构失败: %v", err)
}
}
return sql, nil
}
// 备份表数据到CSV文件
func (m *BackupManager) backupTableData(table string, dataFile string) error {
// 获取表列信息
columnsQuery := fmt.Sprintf("DESCRIBE `%s`", table)
rows, err := m.db.Query(columnsQuery)
if err != nil {
return fmt.Errorf("查询表列信息失败: %v", err)
}
defer rows.Close()
var columns []string
for rows.Next() {
var (
field, typ, null, key, extra string
defaultVal sql.NullString
)
if err := rows.Scan(&field, &typ, &null, &key, &defaultVal, &extra); err != nil {
return fmt.Errorf("扫描表列信息失败: %v", err)
}
columns = append(columns, field)
}
// 打开数据文件
file, err := os.Create(dataFile)
if err != nil {
return fmt.Errorf("创建数据文件失败: %v", err)
}
defer file.Close()
writer := csv.NewWriter(file)
defer writer.Flush()
// 写入列名
if err := writer.Write(columns); err != nil {
return fmt.Errorf("写入列名失败: %v", err)
}
// 分页查询数据
limit := 1000
offset := 0
columnsStr := "`" + strings.Join(columns, "`, `") + "`"
for {
dataQuery := fmt.Sprintf("SELECT %s FROM `%s` LIMIT %d OFFSET %d",
columnsStr, table, limit, offset)
dataRows, err := m.db.Query(dataQuery)
if err != nil {
return fmt.Errorf("查询表数据失败: %v", err)
}
// 获取列类型信息
columnTypes, err := dataRows.ColumnTypes()
if err != nil {
dataRows.Close()
return fmt.Errorf("获取列类型失败: %v", err)
}
rowCount := 0
for dataRows.Next() {
rowCount++
// 创建一个接口切片来存储一行数据
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}
// 扫描行数据
if err := dataRows.Scan(valuePtrs...); err != nil {
dataRows.Close()
return fmt.Errorf("扫描表数据失败: %v", err)
}
// 处理数据并写入CSV
csvRow := make([]string, len(columns))
for i, val := range values {
colType := columnTypes[i].DatabaseTypeName()
csvRow[i] = formatValue(val, colType)
}
if err := writer.Write(csvRow); err != nil {
dataRows.Close()
return fmt.Errorf("写入数据到CSV失败: %v", err)
}
}
dataRows.Close()
// 如果获取的行数小于limit,说明已经到最后一页
if rowCount < limit {
break
}
offset += limit
}
return nil
}
// 格式化值为字符串
func formatValue(value interface{}, colType string) string {
if value == nil {
return "NULL"
}
switch v := value.(type) {
case []byte:
// 处理二进制数据
return string(v)
case string:
return v
case int64:
return strconv.FormatInt(v, 10)
case float64:
return strconv.FormatFloat(v, 'f', -1, 64)
case bool:
if v {
return "1"
}
return "0"
case time.Time:
return v.Format("2006-01-02 15:04:05")
default:
return fmt.Sprintf("%v", v)
}
}
// 创建备份目录
func (m *BackupManager) createBackupDir(timestamp string) (string, error) {
backupPath := filepath.Join(m.config.BackupDir,
fmt.Sprintf("%s_%s", m.config.Database, timestamp))
if err := os.MkdirAll(backupPath, 0755); err != nil {
return "", fmt.Errorf("创建备份目录失败: %v", err)
}
return backupPath, nil
}
// 备份表结构到文件
func (m *BackupManager) backupTableSchema(table string, schemaFile string) error {
schema, err := m.getTableSchema(table)
if err != nil {
return err
}
// 写入表结构到文件
if err := os.WriteFile(schemaFile, []byte(schema+";\n"), 0644); err != nil {
return fmt.Errorf("写入表结构失败: %v", err)
}
return nil
}
// 清理旧备份
func (m *BackupManager) cleanOldBackups() error {
if m.config.KeepDays <= 0 {
return nil
}
cutoffTime := time.Now().AddDate(0, 0, -m.config.KeepDays)
// 读取备份目录
entries, err := os.ReadDir(m.config.BackupDir)
if err != nil {
return fmt.Errorf("读取备份目录失败: %v", err)
}
for _, entry := range entries {
if !entry.IsDir() {
continue
}
// 检查目录名是否符合备份命名规范
if !strings.HasPrefix(entry.Name(), m.config.Database+"_") {
continue
}
// 获取目录修改时间
info, err := entry.Info()
if err != nil {
log.Printf("获取目录信息失败: %v", err)
continue
}
// 如果目录超过保留天数,删除它
if info.ModTime().Before(cutoffTime) {
path := filepath.Join(m.config.BackupDir, entry.Name())
if err := os.RemoveAll(path); err != nil {
log.Printf("删除旧备份失败: %v", err)
} else {
log.Printf("已删除旧备份: %s", path)
}
}
}
return nil
}
// 执行备份
func (m *BackupManager) performBackup() BackupResult {
result := BackupResult{
StartTime: time.Now(),
Success: false,
}
// 检查是否已经在运行备份
m.mutex.Lock()
if m.running {
m.mutex.Unlock()
result.Message = "备份正在进行中"
return result
}
m.running = true
m.mutex.Unlock()
defer func() {
m.mutex.Lock()
m.running = false
m.mutex.Unlock()
result.EndTime = time.Now()
}()
// 获取当前时间戳作为备份标识
timestamp := result.StartTime.Format("20060102150405")
// 创建备份目录
backupDir, err := m.createBackupDir(timestamp)
if err != nil {
result.Message = fmt.Sprintf("创建备份目录失败: %v", err)
return result
}
result.Filename = filepath.Base(backupDir)
// 获取所有表
tables, err := m.getTables()
if err != nil {
result.Message = fmt.Sprintf("获取表列表失败: %v", err)
return result
}
// 备份每个表
for _, table := range tables {
log.Printf("开始备份表: %s", table)
// 备份表结构
schemaFile := filepath.Join(backupDir, table+".sql")
if err := m.backupTableSchema(table, schemaFile); err != nil {
result.Message = fmt.Sprintf("备份表 %s 结构失败: %v", table, err)
return result
}
// 备份表数据
dataFile := filepath.Join(backupDir, table+".csv")
if err := m.backupTableData(table, dataFile); err != nil {
result.Message = fmt.Sprintf("备份表 %s 数据失败: %v", table, err)
return result
}
log.Printf("表 %s 备份完成", table)
}
// 清理旧备份
if err := m.cleanOldBackups(); err != nil {
log.Printf("清理旧备份时出错: %v", err)
}
result.Success = true
result.Message = fmt.Sprintf("数据库 %s 备份成功,存储在: %s", m.config.Database, backupDir)
return result
}
// StartBackup 启动备份(异步或同步)
func (m *BackupManager) StartBackup() (BackupResult, chan BackupResult) {
if m.config.Async {
resultChan := make(chan BackupResult, 1)
go func() {
result := m.performBackup()
resultChan <- result
close(resultChan)
}()
return BackupResult{Success: true, Message: "异步备份已启动"}, resultChan
}
result := m.performBackup()
return result, nil
}
// Close 关闭数据库连接
func (m *BackupManager) Close() error {
return m.db.Close()
}
func main() {
// 解析命令行参数
host := flag.String("host", "localhost", "MySQL主机地址")
port := flag.Int("port", 3306, "MySQL端口")
user := flag.String("user", "root", "MySQL用户名")
password := flag.String("password", "", "MySQL密码")
database := flag.String("db", "", "要备份的数据库名")
backupDir := flag.String("dir", "backups", "备份文件存储目录")
keepDays := flag.Int("keep", 7, "备份保留天数")
async := flag.Bool("async", false, "是否异步执行备份")
flag.Parse()
if *database == "" {
log.Fatal("必须指定要备份的数据库名 (-db 参数)")
}
// 创建备份配置
config := BackupConfig{
Host: *host,
Port: *port,
User: *user,
Password: *password,
Database: *database,
BackupDir: *backupDir,
KeepDays: *keepDays,
Async: *async,
}
// 创建备份管理器
manager, err := NewBackupManager(config)
if err != nil {
log.Fatalf("初始化备份管理器失败: %v", err)
}
defer manager.Close()
// 启动备份
log.Println("开始数据库备份...")
result, resultChan := manager.StartBackup()
log.Println(result.Message)
// 如果是异步备份,等待结果
if *async && resultChan != nil {
log.Println("等待备份完成...")
backupResult := <-resultChan
if backupResult.Success {
log.Printf("备份成功,耗时: %v", backupResult.EndTime.Sub(backupResult.StartTime))
log.Println(backupResult.Message)
} else {
log.Printf("备份失败: %s", backupResult.Message)
}
}
}