package main import ( "database/sql" "encoding/csv" "flag" "fmt" "log" "os" "path/filepath" "strconv" "strings" "sync" "time" _ "github.com/go-sql-driver/mysql" ) // BackupConfig 备份配置 type BackupConfig struct { Host string Port int User string Password string Database string BackupDir string KeepDays int Async bool } // BackupResult 备份结果 type BackupResult struct { Success bool Message string StartTime time.Time EndTime time.Time Filename string } // BackupManager 备份管理器 type BackupManager struct { config BackupConfig db *sql.DB mutex sync.Mutex running bool } // NewBackupManager 创建新的备份管理器 func NewBackupManager(config BackupConfig) (*BackupManager, error) { dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true", config.User, config.Password, config.Host, config.Port, config.Database) db, err := sql.Open("mysql", dsn) if err != nil { return nil, fmt.Errorf("无法连接数据库: %v", err) } // 测试连接 if err := db.Ping(); err != nil { return nil, fmt.Errorf("数据库连接失败: %v", err) } return &BackupManager{ config: config, db: db, }, nil } // 获取数据库中所有表 func (m *BackupManager) getTables() ([]string, error) { query := "SHOW TABLES" rows, err := m.db.Query(query) if err != nil { return nil, fmt.Errorf("查询表失败: %v", err) } defer rows.Close() var tables []string for rows.Next() { var table string if err := rows.Scan(&table); err != nil { return nil, fmt.Errorf("扫描表名失败: %v", err) } tables = append(tables, table) } return tables, nil } // 获取表结构 func (m *BackupManager) getTableSchema(table string) (string, error) { query := fmt.Sprintf("SHOW CREATE TABLE `%s`", table) rows, err := m.db.Query(query) if err != nil { return "", fmt.Errorf("查询表结构失败: %v", err) } defer rows.Close() var ( tbl string sql string ) if rows.Next() { if err := rows.Scan(&tbl, &sql); err != nil { return "", fmt.Errorf("扫描表结构失败: %v", err) } } return sql, nil } // 备份表数据到CSV文件 func (m *BackupManager) backupTableData(table string, dataFile string) error { // 获取表列信息 columnsQuery := fmt.Sprintf("DESCRIBE `%s`", table) rows, err := m.db.Query(columnsQuery) if err != nil { return fmt.Errorf("查询表列信息失败: %v", err) } defer rows.Close() var columns []string for rows.Next() { var ( field, typ, null, key, extra string defaultVal sql.NullString ) if err := rows.Scan(&field, &typ, &null, &key, &defaultVal, &extra); err != nil { return fmt.Errorf("扫描表列信息失败: %v", err) } columns = append(columns, field) } // 打开数据文件 file, err := os.Create(dataFile) if err != nil { return fmt.Errorf("创建数据文件失败: %v", err) } defer file.Close() writer := csv.NewWriter(file) defer writer.Flush() // 写入列名 if err := writer.Write(columns); err != nil { return fmt.Errorf("写入列名失败: %v", err) } // 分页查询数据 limit := 1000 offset := 0 columnsStr := "`" + strings.Join(columns, "`, `") + "`" for { dataQuery := fmt.Sprintf("SELECT %s FROM `%s` LIMIT %d OFFSET %d", columnsStr, table, limit, offset) dataRows, err := m.db.Query(dataQuery) if err != nil { return fmt.Errorf("查询表数据失败: %v", err) } // 获取列类型信息 columnTypes, err := dataRows.ColumnTypes() if err != nil { dataRows.Close() return fmt.Errorf("获取列类型失败: %v", err) } rowCount := 0 for dataRows.Next() { rowCount++ // 创建一个接口切片来存储一行数据 values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range values { valuePtrs[i] = &values[i] } // 扫描行数据 if err := dataRows.Scan(valuePtrs...); err != nil { dataRows.Close() return fmt.Errorf("扫描表数据失败: %v", err) } // 处理数据并写入CSV csvRow := make([]string, len(columns)) for i, val := range values { colType := columnTypes[i].DatabaseTypeName() csvRow[i] = formatValue(val, colType) } if err := writer.Write(csvRow); err != nil { dataRows.Close() return fmt.Errorf("写入数据到CSV失败: %v", err) } } dataRows.Close() // 如果获取的行数小于limit,说明已经到最后一页 if rowCount < limit { break } offset += limit } return nil } // 格式化值为字符串 func formatValue(value interface{}, colType string) string { if value == nil { return "NULL" } switch v := value.(type) { case []byte: // 处理二进制数据 return string(v) case string: return v case int64: return strconv.FormatInt(v, 10) case float64: return strconv.FormatFloat(v, 'f', -1, 64) case bool: if v { return "1" } return "0" case time.Time: return v.Format("2006-01-02 15:04:05") default: return fmt.Sprintf("%v", v) } } // 创建备份目录 func (m *BackupManager) createBackupDir(timestamp string) (string, error) { backupPath := filepath.Join(m.config.BackupDir, fmt.Sprintf("%s_%s", m.config.Database, timestamp)) if err := os.MkdirAll(backupPath, 0755); err != nil { return "", fmt.Errorf("创建备份目录失败: %v", err) } return backupPath, nil } // 备份表结构到文件 func (m *BackupManager) backupTableSchema(table string, schemaFile string) error { schema, err := m.getTableSchema(table) if err != nil { return err } // 写入表结构到文件 if err := os.WriteFile(schemaFile, []byte(schema+";\n"), 0644); err != nil { return fmt.Errorf("写入表结构失败: %v", err) } return nil } // 清理旧备份 func (m *BackupManager) cleanOldBackups() error { if m.config.KeepDays <= 0 { return nil } cutoffTime := time.Now().AddDate(0, 0, -m.config.KeepDays) // 读取备份目录 entries, err := os.ReadDir(m.config.BackupDir) if err != nil { return fmt.Errorf("读取备份目录失败: %v", err) } for _, entry := range entries { if !entry.IsDir() { continue } // 检查目录名是否符合备份命名规范 if !strings.HasPrefix(entry.Name(), m.config.Database+"_") { continue } // 获取目录修改时间 info, err := entry.Info() if err != nil { log.Printf("获取目录信息失败: %v", err) continue } // 如果目录超过保留天数,删除它 if info.ModTime().Before(cutoffTime) { path := filepath.Join(m.config.BackupDir, entry.Name()) if err := os.RemoveAll(path); err != nil { log.Printf("删除旧备份失败: %v", err) } else { log.Printf("已删除旧备份: %s", path) } } } return nil } // 执行备份 func (m *BackupManager) performBackup() BackupResult { result := BackupResult{ StartTime: time.Now(), Success: false, } // 检查是否已经在运行备份 m.mutex.Lock() if m.running { m.mutex.Unlock() result.Message = "备份正在进行中" return result } m.running = true m.mutex.Unlock() defer func() { m.mutex.Lock() m.running = false m.mutex.Unlock() result.EndTime = time.Now() }() // 获取当前时间戳作为备份标识 timestamp := result.StartTime.Format("20060102150405") // 创建备份目录 backupDir, err := m.createBackupDir(timestamp) if err != nil { result.Message = fmt.Sprintf("创建备份目录失败: %v", err) return result } result.Filename = filepath.Base(backupDir) // 获取所有表 tables, err := m.getTables() if err != nil { result.Message = fmt.Sprintf("获取表列表失败: %v", err) return result } // 备份每个表 for _, table := range tables { log.Printf("开始备份表: %s", table) // 备份表结构 schemaFile := filepath.Join(backupDir, table+".sql") if err := m.backupTableSchema(table, schemaFile); err != nil { result.Message = fmt.Sprintf("备份表 %s 结构失败: %v", table, err) return result } // 备份表数据 dataFile := filepath.Join(backupDir, table+".csv") if err := m.backupTableData(table, dataFile); err != nil { result.Message = fmt.Sprintf("备份表 %s 数据失败: %v", table, err) return result } log.Printf("表 %s 备份完成", table) } // 清理旧备份 if err := m.cleanOldBackups(); err != nil { log.Printf("清理旧备份时出错: %v", err) } result.Success = true result.Message = fmt.Sprintf("数据库 %s 备份成功,存储在: %s", m.config.Database, backupDir) return result } // StartBackup 启动备份(异步或同步) func (m *BackupManager) StartBackup() (BackupResult, chan BackupResult) { if m.config.Async { resultChan := make(chan BackupResult, 1) go func() { result := m.performBackup() resultChan <- result close(resultChan) }() return BackupResult{Success: true, Message: "异步备份已启动"}, resultChan } result := m.performBackup() return result, nil } // Close 关闭数据库连接 func (m *BackupManager) Close() error { return m.db.Close() } func main() { // 解析命令行参数 host := flag.String("host", "localhost", "MySQL主机地址") port := flag.Int("port", 3306, "MySQL端口") user := flag.String("user", "root", "MySQL用户名") password := flag.String("password", "", "MySQL密码") database := flag.String("db", "", "要备份的数据库名") backupDir := flag.String("dir", "backups", "备份文件存储目录") keepDays := flag.Int("keep", 7, "备份保留天数") async := flag.Bool("async", false, "是否异步执行备份") flag.Parse() if *database == "" { log.Fatal("必须指定要备份的数据库名 (-db 参数)") } // 创建备份配置 config := BackupConfig{ Host: *host, Port: *port, User: *user, Password: *password, Database: *database, BackupDir: *backupDir, KeepDays: *keepDays, Async: *async, } // 创建备份管理器 manager, err := NewBackupManager(config) if err != nil { log.Fatalf("初始化备份管理器失败: %v", err) } defer manager.Close() // 启动备份 log.Println("开始数据库备份...") result, resultChan := manager.StartBackup() log.Println(result.Message) // 如果是异步备份,等待结果 if *async && resultChan != nil { log.Println("等待备份完成...") backupResult := <-resultChan if backupResult.Success { log.Printf("备份成功,耗时: %v", backupResult.EndTime.Sub(backupResult.StartTime)) log.Println(backupResult.Message) } else { log.Printf("备份失败: %s", backupResult.Message) } } }