refactor(db): change SQLite driver from github.com/jinzhu/gorm/dialects/sqlite to github.com/glebarez/go-sqlite (#1626)

* sqlite 驱动从 github.com/jinzhu/gorm/dialects/sqlite 改为 github.com/glebarez/go-sqlite,以移除对 cgo 的依赖

* // 兼容已有配置中的 "sqlite3" 配置项

* Update models/init.go: 修改变量名
This commit is contained in:
VigorFox
2023-02-08 09:53:41 +08:00
committed by GitHub
parent 6d1c44f21b
commit 9c58278e08
7 changed files with 343 additions and 28 deletions

288
models/dialect_sqlite.go Normal file
View File

@@ -0,0 +1,288 @@
package model
import (
"fmt"
"reflect"
"regexp"
"strconv"
"strings"
"time"
"github.com/jinzhu/gorm"
)
var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+")
// DefaultForeignKeyNamer contains the default foreign key name generator method
type DefaultForeignKeyNamer struct {
}
type commonDialect struct {
db gorm.SQLCommon
DefaultForeignKeyNamer
}
func (commonDialect) GetName() string {
return "common"
}
func (s *commonDialect) SetDB(db gorm.SQLCommon) {
s.db = db
}
func (commonDialect) BindVar(i int) string {
return "$$$" // ?
}
func (commonDialect) Quote(key string) string {
return fmt.Sprintf(`"%s"`, key)
}
func (s *commonDialect) fieldCanAutoIncrement(field *gorm.StructField) bool {
if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok {
return strings.ToLower(value) != "false"
}
return field.IsPrimaryKey
}
func (s *commonDialect) DataTypeOf(field *gorm.StructField) string {
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "BOOLEAN"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
sqlType = "INTEGER AUTO_INCREMENT"
} else {
sqlType = "INTEGER"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
sqlType = "BIGINT AUTO_INCREMENT"
} else {
sqlType = "BIGINT"
}
case reflect.Float32, reflect.Float64:
sqlType = "FLOAT"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("VARCHAR(%d)", size)
} else {
sqlType = "VARCHAR(65532)"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "TIMESTAMP"
}
default:
if _, ok := dataValue.Interface().([]byte); ok {
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("BINARY(%d)", size)
} else {
sqlType = "BINARY(65532)"
}
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for commonDialect", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func currentDatabaseAndTable(dialect gorm.Dialect, tableName string) (string, string) {
if strings.Contains(tableName, ".") {
splitStrings := strings.SplitN(tableName, ".", 2)
return splitStrings[0], splitStrings[1]
}
return dialect.CurrentDatabase(), tableName
}
func (s commonDialect) HasIndex(tableName string, indexName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count)
return count > 0
}
func (s commonDialect) RemoveIndex(tableName string, indexName string) error {
_, err := s.db.Exec(fmt.Sprintf("DROP INDEX %v", indexName))
return err
}
func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bool {
return false
}
func (s commonDialect) HasTable(tableName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count)
return count > 0
}
func (s commonDialect) HasColumn(tableName string, columnName string) bool {
var count int
currentDatabase, tableName := currentDatabaseAndTable(&s, tableName)
s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count)
return count > 0
}
func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error {
_, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ))
return err
}
func (s commonDialect) CurrentDatabase() (name string) {
s.db.QueryRow("SELECT DATABASE()").Scan(&name)
return
}
func (commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string) {
if limit != nil {
if parsedLimit, err := strconv.ParseInt(fmt.Sprint(limit), 0, 0); err == nil && parsedLimit >= 0 {
sql += fmt.Sprintf(" LIMIT %d", parsedLimit)
}
}
if offset != nil {
if parsedOffset, err := strconv.ParseInt(fmt.Sprint(offset), 0, 0); err == nil && parsedOffset >= 0 {
sql += fmt.Sprintf(" OFFSET %d", parsedOffset)
}
}
return
}
func (commonDialect) SelectFromDummyTable() string {
return ""
}
func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string {
return ""
}
func (commonDialect) DefaultValueStr() string {
return "DEFAULT VALUES"
}
// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference
func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string {
keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_"))
keyName = keyNameRegex.ReplaceAllString(keyName, "_")
return keyName
}
// NormalizeIndexAndColumn returns argument's index name and column name without doing anything
func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) {
return indexName, columnName
}
// IsByteArrayOrSlice returns true of the reflected value is an array or slice
func IsByteArrayOrSlice(value reflect.Value) bool {
return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0))
}
type sqlite struct {
commonDialect
}
func init() {
gorm.RegisterDialect("sqlite", &sqlite{})
}
func (sqlite) GetName() string {
return "sqlite"
}
// Get Data Type for Sqlite Dialect
func (s *sqlite) DataTypeOf(field *gorm.StructField) string {
var dataValue, sqlType, size, additionalType = gorm.ParseFieldStructForDialect(field, s)
if sqlType == "" {
switch dataValue.Kind() {
case reflect.Bool:
sqlType = "bool"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "integer"
}
case reflect.Int64, reflect.Uint64:
if s.fieldCanAutoIncrement(field) {
field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT")
sqlType = "integer primary key autoincrement"
} else {
sqlType = "bigint"
}
case reflect.Float32, reflect.Float64:
sqlType = "real"
case reflect.String:
if size > 0 && size < 65532 {
sqlType = fmt.Sprintf("varchar(%d)", size)
} else {
sqlType = "text"
}
case reflect.Struct:
if _, ok := dataValue.Interface().(time.Time); ok {
sqlType = "datetime"
}
default:
if IsByteArrayOrSlice(dataValue) {
sqlType = "blob"
}
}
}
if sqlType == "" {
panic(fmt.Sprintf("invalid sql type %s (%s) for sqlite", dataValue.Type().Name(), dataValue.Kind().String()))
}
if strings.TrimSpace(additionalType) == "" {
return sqlType
}
return fmt.Sprintf("%v %v", sqlType, additionalType)
}
func (s sqlite) HasIndex(tableName string, indexName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND sql LIKE '%%INDEX %v ON%%'", indexName), tableName).Scan(&count)
return count > 0
}
func (s sqlite) HasTable(tableName string) bool {
var count int
s.db.QueryRow("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", tableName).Scan(&count)
return count > 0
}
func (s sqlite) HasColumn(tableName string, columnName string) bool {
var count int
s.db.QueryRow(fmt.Sprintf("SELECT count(*) FROM sqlite_master WHERE tbl_name = ? AND (sql LIKE '%%\"%v\" %%' OR sql LIKE '%%%v %%');\n", columnName, columnName), tableName).Scan(&count)
return count > 0
}
func (s sqlite) CurrentDatabase() (name string) {
var (
ifaces = make([]interface{}, 3)
pointers = make([]*string, 3)
i int
)
for i = 0; i < 3; i++ {
ifaces[i] = &pointers[i]
}
if err := s.db.QueryRow("PRAGMA database_list").Scan(ifaces...); err != nil {
return
}
if pointers[1] != nil {
name = *pointers[1]
}
return
}

View File

@@ -106,7 +106,7 @@ func TestFolder_GetChildFolder(t *testing.T) {
}
func TestGetRecursiveChildFolderSQLite(t *testing.T) {
conf.DatabaseConfig.Type = "sqlite3"
conf.DatabaseConfig.Type = "sqlite"
asserts := assert.New(t)
// 测试目录结构

View File

@@ -9,10 +9,10 @@ import (
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
_ "github.com/glebarez/go-sqlite"
_ "github.com/jinzhu/gorm/dialects/mssql"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/postgres"
_ "github.com/jinzhu/gorm/dialects/sqlite"
)
// DB 数据库链接单例
@@ -23,20 +23,26 @@ func Init() {
util.Log().Info("Initializing database connection...")
var (
db *gorm.DB
err error
db *gorm.DB
err error
confDBType string = conf.DatabaseConfig.Type
)
// 兼容已有配置中的 "sqlite3" 配置项
if confDBType == "sqlite3" {
confDBType = "sqlite"
}
if gin.Mode() == gin.TestMode {
// 测试模式下,使用内存数据库
db, err = gorm.Open("sqlite3", ":memory:")
db, err = gorm.Open("sqlite", ":memory:")
} else {
switch conf.DatabaseConfig.Type {
case "UNSET", "sqlite", "sqlite3":
// 未指定数据库或者明确指定为 sqlite 时,使用 SQLite3 数据库
db, err = gorm.Open("sqlite3", util.RelativePath(conf.DatabaseConfig.DBFile))
switch confDBType {
case "UNSET", "sqlite":
// 未指定数据库或者明确指定为 sqlite 时,使用 SQLite 数据库
db, err = gorm.Open("sqlite", util.RelativePath(conf.DatabaseConfig.DBFile))
case "postgres":
db, err = gorm.Open(conf.DatabaseConfig.Type, fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
db, err = gorm.Open(confDBType, fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
conf.DatabaseConfig.Host,
conf.DatabaseConfig.User,
conf.DatabaseConfig.Password,
@@ -53,14 +59,14 @@ func Init() {
conf.DatabaseConfig.Port)
}
db, err = gorm.Open(conf.DatabaseConfig.Type, fmt.Sprintf("%s:%s@%s/%s?charset=%s&parseTime=True&loc=Local",
db, err = gorm.Open(confDBType, fmt.Sprintf("%s:%s@%s/%s?charset=%s&parseTime=True&loc=Local",
conf.DatabaseConfig.User,
conf.DatabaseConfig.Password,
host,
conf.DatabaseConfig.Name,
conf.DatabaseConfig.Charset))
default:
util.Log().Panic("Unsupported database type %q.", conf.DatabaseConfig.Type)
util.Log().Panic("Unsupported database type %q.", confDBType)
}
}
@@ -83,7 +89,7 @@ func Init() {
//设置连接池
db.DB().SetMaxIdleConns(50)
if conf.DatabaseConfig.Type == "sqlite" || conf.DatabaseConfig.Type == "sqlite3" || conf.DatabaseConfig.Type == "UNSET" {
if confDBType == "sqlite" || confDBType == "UNSET" {
db.DB().SetMaxOpenConns(1)
} else {
db.DB().SetMaxOpenConns(100)

View File

@@ -10,8 +10,8 @@ import (
func TestMigration(t *testing.T) {
asserts := assert.New(t)
conf.DatabaseConfig.Type = "sqlite3"
DB, _ = gorm.Open("sqlite3", ":memory:")
conf.DatabaseConfig.Type = "sqlite"
DB, _ = gorm.Open("sqlite", ":memory:")
asserts.NotPanics(func() {
migration()