48 lines
1.1 KiB
Go
48 lines
1.1 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"embed"
|
|
"fmt"
|
|
"strings"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
_ "modernc.org/sqlite"
|
|
|
|
"mind/internal/config"
|
|
)
|
|
|
|
//go:embed migrations/*.sql
|
|
var migrationFS embed.FS
|
|
|
|
func OpenAndMigrate(cfg config.Config) (*sql.DB, error) {
|
|
dsn := cfg.DSN
|
|
drv := strings.ToLower(cfg.Driver)
|
|
if drv != "mysql" && drv != "sqlite" {
|
|
drv = "sqlite"
|
|
}
|
|
db, err := sql.Open(drv, dsn)
|
|
if err != nil { return nil, err }
|
|
if err := db.Ping(); err != nil { return nil, err }
|
|
if err := runMigrations(db, drv); err != nil { return nil, err }
|
|
return db, nil
|
|
}
|
|
|
|
func runMigrations(db *sql.DB, driver string) error {
|
|
files, err := migrationFS.ReadDir("migrations")
|
|
if err != nil { return err }
|
|
for _, f := range files {
|
|
b, err := migrationFS.ReadFile("migrations/"+f.Name())
|
|
if err != nil { return err }
|
|
sqlText := string(b)
|
|
if driver == "sqlite" {
|
|
// very minor compatibility tweak: drop ENUM
|
|
sqlText = strings.ReplaceAll(sqlText, "ENUM('user','assistant')", "TEXT")
|
|
}
|
|
if _, err := db.Exec(sqlText); err != nil {
|
|
return fmt.Errorf("migration %s: %w", f.Name(), err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|