48 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			48 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package db
 | |
| 
 | |
import (
	"database/sql"
	"embed"
	"fmt"
	"path"
	"regexp"
	"strings"

	_ "github.com/go-sql-driver/mysql"
	_ "modernc.org/sqlite"

	"mind/internal/config"
)
 | |
| 
 | |
// migrationFS embeds every .sql file found under the migrations directory at
// build time, so the schema ships inside the binary. runMigrations reads and
// executes the files from it.
//
//go:embed migrations/*.sql
var migrationFS embed.FS
 | |
| 
 | |
| func OpenAndMigrate(cfg config.Config) (*sql.DB, error) {
 | |
| 	dsn := cfg.DSN
 | |
| 	drv := strings.ToLower(cfg.Driver)
 | |
| 	if drv != "mysql" && drv != "sqlite" {
 | |
| 		drv = "sqlite"
 | |
| 	}
 | |
| 	db, err := sql.Open(drv, dsn)
 | |
| 	if err != nil { return nil, err }
 | |
| 	if err := db.Ping(); err != nil { return nil, err }
 | |
| 	if err := runMigrations(db, drv); err != nil { return nil, err }
 | |
| 	return db, nil
 | |
| }
 | |
| 
 | |
| func runMigrations(db *sql.DB, driver string) error {
 | |
| 	files, err := migrationFS.ReadDir("migrations")
 | |
| 	if err != nil { return err }
 | |
| 	for _, f := range files {
 | |
| 		b, err := migrationFS.ReadFile("migrations/"+f.Name())
 | |
| 		if err != nil { return err }
 | |
| 		sqlText := string(b)
 | |
| 		if driver == "sqlite" {
 | |
| 			// very minor compatibility tweak: drop ENUM
 | |
| 			sqlText = strings.ReplaceAll(sqlText, "ENUM('user','assistant')", "TEXT")
 | |
| 		}
 | |
| 		if _, err := db.Exec(sqlText); err != nil {
 | |
| 			return fmt.Errorf("migration %s: %w", f.Name(), err)
 | |
| 		}
 | |
| 	}
 | |
| 	return nil
 | |
| }
 |