db.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package database
  2. import (
  3. "database/sql"
  4. "embed"
  5. "fmt"
  6. "log/slog"
  7. "github.com/golang-migrate/migrate/v4"
  8. "github.com/golang-migrate/migrate/v4/database/sqlite3"
  9. "github.com/golang-migrate/migrate/v4/source/iofs"
  10. _ "github.com/mattn/go-sqlite3"
  11. )
  12. //go:embed migrations/*.sql
  13. var migrations embed.FS
  14. func Open(dbPath string) (*sql.DB, error) {
  15. connStr := dbPath + "?_journal_mode=WAL&_foreign_keys=1"
  16. migrationDb, err := sql.Open("sqlite3", connStr)
  17. if err != nil {
  18. return nil, fmt.Errorf("open db for migrations: %w", err)
  19. }
  20. if err := migrationDb.Ping(); err != nil {
  21. migrationDb.Close()
  22. return nil, fmt.Errorf("ping db: %w", err)
  23. }
  24. if err := runMigrations(migrationDb); err != nil {
  25. migrationDb.Close()
  26. return nil, fmt.Errorf("run migrations: %w", err)
  27. }
  28. migrationDb.Close()
  29. db, err := sql.Open("sqlite3", connStr)
  30. if err != nil {
  31. return nil, fmt.Errorf("open db: %w", err)
  32. }
  33. if err := db.Ping(); err != nil {
  34. db.Close()
  35. return nil, fmt.Errorf("ping db: %w", err)
  36. }
  37. return db, nil
  38. }
  39. func runMigrations(db *sql.DB) error {
  40. sourceDriver, err := iofs.New(migrations, "migrations")
  41. if err != nil {
  42. return fmt.Errorf("create migration source: %w", err)
  43. }
  44. driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
  45. if err != nil {
  46. return fmt.Errorf("create database driver: %w", err)
  47. }
  48. m, err := migrate.NewWithInstance("iofs", sourceDriver, "sqlite3", driver)
  49. if err != nil {
  50. return fmt.Errorf("create migrate instance: %w", err)
  51. }
  52. defer m.Close()
  53. currentVersion, _, err := m.Version()
  54. if err != nil && err != migrate.ErrNilVersion {
  55. return fmt.Errorf("get current migration version: %w", err)
  56. }
  57. if err == migrate.ErrNilVersion {
  58. slog.Info("running database migrations", "currentVersion", "none")
  59. } else {
  60. slog.Info("running database migrations", "currentVersion", currentVersion)
  61. }
  62. upErr := m.Up()
  63. if upErr != nil && upErr != migrate.ErrNoChange {
  64. return fmt.Errorf("run migrations: %w", upErr)
  65. }
  66. newVersion, dirty, versionErr := m.Version()
  67. if versionErr != nil && versionErr != migrate.ErrNilVersion {
  68. return fmt.Errorf("get migration version after execution: %w", versionErr)
  69. }
  70. if upErr == migrate.ErrNoChange {
  71. slog.Info("no pending migrations", "currentVersion", currentVersion)
  72. } else if dirty {
  73. slog.Warn("migration completed but database is in dirty state", "version", newVersion)
  74. } else {
  75. if err == migrate.ErrNilVersion {
  76. slog.Info("migrations completed", "appliedVersion", newVersion)
  77. } else if currentVersion != newVersion {
  78. slog.Info("migrations completed", "fromVersion", currentVersion, "toVersion", newVersion)
  79. } else {
  80. slog.Info("migrations completed", "version", newVersion)
  81. }
  82. }
  83. return nil
  84. }