mirror of https://github.com/grafana/grafana.git
Chore: add/update sqlstore-related helper functions (#77408)
* add/update sqlstore-related helper functions * add documentation & tests for InsertQuery and UpdateQuery, make generated SQL deterministic by sorting columns * remove old log line
This commit is contained in:
parent
6b729389b5
commit
67b2972052
|
|
@ -5,6 +5,7 @@ import (
|
|||
"os"
|
||||
|
||||
"xorm.io/core"
|
||||
"xorm.io/xorm"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
|
|
@ -29,6 +30,8 @@ type DB interface {
|
|||
GetDialect() migrator.Dialect
|
||||
// GetDBType returns the name of the database type available to the runtime.
|
||||
GetDBType() core.DbType
|
||||
// GetEngine returns the underlying xorm engine.
|
||||
GetEngine() *xorm.Engine
|
||||
// GetSqlxSession is an experimental extension to use sqlx instead of xorm to
|
||||
// communicate with the database.
|
||||
GetSqlxSession() *session.SessionDB
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
|
||||
"xorm.io/core"
|
||||
"xorm.io/xorm"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore"
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
|
|
@ -43,6 +44,10 @@ func (f *FakeDB) GetDialect() migrator.Dialect {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (f *FakeDB) GetEngine() *xorm.Engine {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FakeDB) GetSqlxSession() *session.SessionDB {
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -100,7 +100,6 @@ func (s *XormLogger) SetLevel(l core.LogLevel) {
|
|||
|
||||
// ShowSQL implement core.ILogger
|
||||
func (s *XormLogger) ShowSQL(show ...bool) {
|
||||
s.grafanaLog.Error("ShowSQL", "show", "show")
|
||||
if len(show) == 0 {
|
||||
s.showSQL = true
|
||||
return
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
|
|
@ -73,6 +74,14 @@ type Dialect interface {
|
|||
Unlock(LockCfg) error
|
||||
|
||||
GetDBName(string) (string, error)
|
||||
|
||||
// InsertQuery accepts a table name and a map of column names to values to insert.
|
||||
// It returns a query string and a slice of parameters that can be executed against the database.
|
||||
InsertQuery(tableName string, row map[string]any) (string, []any, error)
|
||||
// UpdateQuery accepts a table name, a map of column names to values to update, and a map of
|
||||
// column names to values to use in the where clause.
|
||||
// It returns a query string and a slice of parameters that can be executed against the database.
|
||||
UpdateQuery(tableName string, row map[string]any, where map[string]any) (string, []any, error)
|
||||
}
|
||||
|
||||
type LockCfg struct {
|
||||
|
|
@ -344,3 +353,71 @@ func (b *BaseDialect) OrderBy(order string) string {
|
|||
func (b *BaseDialect) GetDBName(_ string) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (b *BaseDialect) InsertQuery(tableName string, row map[string]any) (string, []any, error) {
|
||||
if len(row) < 1 {
|
||||
return "", nil, fmt.Errorf("no columns provided")
|
||||
}
|
||||
|
||||
// allocate slices
|
||||
cols := make([]string, 0, len(row))
|
||||
vals := make([]any, 0, len(row))
|
||||
keys := make([]string, 0, len(row))
|
||||
|
||||
// create sorted list of columns
|
||||
for col := range row {
|
||||
keys = append(keys, col)
|
||||
}
|
||||
slices.Sort[string](keys)
|
||||
|
||||
// build query and values
|
||||
for _, col := range keys {
|
||||
cols = append(cols, b.dialect.Quote(col))
|
||||
vals = append(vals, row[col])
|
||||
}
|
||||
|
||||
return fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", b.dialect.Quote(tableName), strings.Join(cols, ", "), strings.Repeat("?, ", len(row)-1)+"?"), vals, nil
|
||||
}
|
||||
|
||||
func (b *BaseDialect) UpdateQuery(tableName string, row map[string]any, where map[string]any) (string, []any, error) {
|
||||
if len(row) < 1 {
|
||||
return "", nil, fmt.Errorf("no columns provided")
|
||||
}
|
||||
|
||||
if len(where) < 1 {
|
||||
return "", nil, fmt.Errorf("no where clause provided")
|
||||
}
|
||||
|
||||
// allocate slices
|
||||
cols := make([]string, 0, len(row))
|
||||
whereCols := make([]string, 0, len(where))
|
||||
vals := make([]any, 0, len(row)+len(where))
|
||||
keys := make([]string, 0, len(row))
|
||||
|
||||
// create sorted list of columns to update
|
||||
for col := range row {
|
||||
keys = append(keys, col)
|
||||
}
|
||||
slices.Sort[string](keys)
|
||||
|
||||
// build update query and values
|
||||
for _, col := range keys {
|
||||
cols = append(cols, b.dialect.Quote(col)+"=?")
|
||||
vals = append(vals, row[col])
|
||||
}
|
||||
|
||||
// create sorted list of columns for where clause
|
||||
keys = make([]string, 0, len(where))
|
||||
for col := range where {
|
||||
keys = append(keys, col)
|
||||
}
|
||||
slices.Sort[string](keys)
|
||||
|
||||
// build where clause and values
|
||||
for _, col := range keys {
|
||||
whereCols = append(whereCols, b.dialect.Quote(col)+"=?")
|
||||
vals = append(vals, where[col])
|
||||
}
|
||||
|
||||
return fmt.Sprintf("UPDATE %s SET %s WHERE %s", b.dialect.Quote(tableName), strings.Join(cols, ", "), strings.Join(whereCols, " AND ")), vals, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,117 @@
|
|||
package migrator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInsertQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableName string
|
||||
values map[string]any
|
||||
expectedErr bool
|
||||
expectedPostgresQuery string
|
||||
expectedPostgresArgs []any
|
||||
expectedMySQLQuery string
|
||||
expectedMySQLArgs []any
|
||||
expectedSQLiteQuery string
|
||||
expectedSQLiteArgs []any
|
||||
}{
|
||||
{
|
||||
"insert one",
|
||||
"some_table",
|
||||
map[string]any{"col1": "val1", "col2": "val2", "col3": "val3"},
|
||||
false,
|
||||
"INSERT INTO \"some_table\" (\"col1\", \"col2\", \"col3\") VALUES (?, ?, ?)",
|
||||
[]any{"val1", "val2", "val3"},
|
||||
"INSERT INTO `some_table` (`col1`, `col2`, `col3`) VALUES (?, ?, ?)",
|
||||
[]any{"val1", "val2", "val3"},
|
||||
"INSERT INTO `some_table` (`col1`, `col2`, `col3`) VALUES (?, ?, ?)",
|
||||
[]any{"val1", "val2", "val3"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var db Dialect
|
||||
db = NewPostgresDialect()
|
||||
q, args, err := db.InsertQuery(tc.tableName, tc.values)
|
||||
|
||||
require.True(t, (err != nil) == tc.expectedErr)
|
||||
require.Equal(t, tc.expectedPostgresQuery, q, "Postgres query incorrect")
|
||||
require.Equal(t, tc.expectedPostgresArgs, args, "Postgres args incorrect")
|
||||
|
||||
db = NewMysqlDialect()
|
||||
q, args, err = db.InsertQuery(tc.tableName, tc.values)
|
||||
|
||||
require.True(t, (err != nil) == tc.expectedErr)
|
||||
require.Equal(t, tc.expectedMySQLQuery, q, "MySQL query incorrect")
|
||||
require.Equal(t, tc.expectedMySQLArgs, args, "MySQL args incorrect")
|
||||
|
||||
db = NewSQLite3Dialect()
|
||||
q, args, err = db.InsertQuery(tc.tableName, tc.values)
|
||||
|
||||
require.True(t, (err != nil) == tc.expectedErr)
|
||||
require.Equal(t, tc.expectedSQLiteQuery, q, "SQLite query incorrect")
|
||||
require.Equal(t, tc.expectedSQLiteArgs, args, "SQLite args incorrect")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableName string
|
||||
values map[string]any
|
||||
where map[string]any
|
||||
expectedErr bool
|
||||
expectedPostgresQuery string
|
||||
expectedPostgresArgs []any
|
||||
expectedMySQLQuery string
|
||||
expectedMySQLArgs []any
|
||||
expectedSQLiteQuery string
|
||||
expectedSQLiteArgs []any
|
||||
}{
|
||||
{
|
||||
"insert one",
|
||||
"some_table",
|
||||
map[string]any{"col1": "val1", "col2": "val2", "col3": "val3"},
|
||||
map[string]any{"key1": 10},
|
||||
false,
|
||||
"UPDATE \"some_table\" SET \"col1\"=?, \"col2\"=?, \"col3\"=? WHERE \"key1\"=?",
|
||||
[]any{"val1", "val2", "val3", 10},
|
||||
"UPDATE `some_table` SET `col1`=?, `col2`=?, `col3`=? WHERE `key1`=?",
|
||||
[]any{"val1", "val2", "val3", 10},
|
||||
"UPDATE `some_table` SET `col1`=?, `col2`=?, `col3`=? WHERE `key1`=?",
|
||||
[]any{"val1", "val2", "val3", 10},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var db Dialect
|
||||
db = NewPostgresDialect()
|
||||
q, args, err := db.UpdateQuery(tc.tableName, tc.values, tc.where)
|
||||
|
||||
require.True(t, (err != nil) == tc.expectedErr)
|
||||
require.Equal(t, tc.expectedPostgresQuery, q, "Postgres query incorrect")
|
||||
require.Equal(t, tc.expectedPostgresArgs, args, "Postgres args incorrect")
|
||||
|
||||
db = NewMysqlDialect()
|
||||
q, args, err = db.UpdateQuery(tc.tableName, tc.values, tc.where)
|
||||
|
||||
require.True(t, (err != nil) == tc.expectedErr)
|
||||
require.Equal(t, tc.expectedMySQLQuery, q, "MySQL query incorrect")
|
||||
require.Equal(t, tc.expectedMySQLArgs, args, "MySQL args incorrect")
|
||||
|
||||
db = NewSQLite3Dialect()
|
||||
q, args, err = db.UpdateQuery(tc.tableName, tc.values, tc.where)
|
||||
|
||||
require.True(t, (err != nil) == tc.expectedErr)
|
||||
require.Equal(t, tc.expectedSQLiteQuery, q, "SQLite query incorrect")
|
||||
require.Equal(t, tc.expectedSQLiteArgs, args, "SQLite args incorrect")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -59,7 +59,7 @@ func NewScopedMigrator(engine *xorm.Engine, cfg *setting.Cfg, scope string) *Mig
|
|||
mg.Logger = log.New("migrator")
|
||||
} else {
|
||||
mg.tableName = scope + "_migration_log"
|
||||
mg.Logger = log.New(scope + " migrator")
|
||||
mg.Logger = log.New(scope + "-migrator")
|
||||
}
|
||||
return mg
|
||||
}
|
||||
|
|
|
|||
|
|
@ -69,6 +69,9 @@ func (db *MySQLDialect) SQLType(c *Column) string {
|
|||
c.Length = 64
|
||||
case DB_NVarchar:
|
||||
res = DB_Varchar
|
||||
case DB_Uuid:
|
||||
res = DB_Char
|
||||
c.Length = 36
|
||||
default:
|
||||
res = c.Type
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ func (gs *SessionDB) NamedExec(ctx context.Context, query string, arg any) (sql.
|
|||
return gs.sqlxdb.NamedExecContext(ctx, gs.sqlxdb.Rebind(query), arg)
|
||||
}
|
||||
|
||||
func (gs *SessionDB) driverName() string {
|
||||
func (gs *SessionDB) DriverName() string {
|
||||
return gs.sqlxdb.DriverName()
|
||||
}
|
||||
|
||||
|
|
@ -69,7 +69,7 @@ func (gs *SessionDB) WithTransaction(ctx context.Context, callback func(*Session
|
|||
}
|
||||
|
||||
func (gs *SessionDB) ExecWithReturningId(ctx context.Context, query string, args ...any) (int64, error) {
|
||||
return execWithReturningId(ctx, gs.driverName(), query, gs, args...)
|
||||
return execWithReturningId(ctx, gs.DriverName(), query, gs, args...)
|
||||
}
|
||||
|
||||
type SessionTx struct {
|
||||
|
|
@ -125,3 +125,10 @@ func execWithReturningId(ctx context.Context, driverName string, query string, s
|
|||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
type SessionQuerier interface {
|
||||
Query(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
var _ SessionQuerier = &SessionDB{}
|
||||
var _ SessionQuerier = &SessionTx{}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,13 @@
|
|||
package sqlstash
|
||||
|
||||
import "strings"
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/grafana/grafana/pkg/services/sqlstore/migrator"
|
||||
)
|
||||
|
||||
type selectQuery struct {
|
||||
dialect migrator.Dialect
|
||||
fields []string // SELECT xyz
|
||||
from string // FROM object
|
||||
limit int64
|
||||
|
|
@ -12,21 +17,27 @@ type selectQuery struct {
|
|||
args []any
|
||||
}
|
||||
|
||||
func (q *selectQuery) addWhere(f string, val any) {
|
||||
q.args = append(q.args, val)
|
||||
q.where = append(q.where, f+"=?")
|
||||
func (q *selectQuery) addWhere(f string, val ...any) {
|
||||
q.args = append(q.args, val...)
|
||||
// if the field contains a question mark, we assume it's a raw where clause
|
||||
if strings.Contains(f, "?") {
|
||||
q.where = append(q.where, f)
|
||||
// otherwise we assume it's a field name
|
||||
} else {
|
||||
q.where = append(q.where, q.dialect.Quote(f)+"=?")
|
||||
}
|
||||
}
|
||||
|
||||
func (q *selectQuery) addWhereInSubquery(f string, subquery string, subqueryArgs []any) {
|
||||
q.args = append(q.args, subqueryArgs...)
|
||||
q.where = append(q.where, f+" IN ("+subquery+")")
|
||||
q.where = append(q.where, q.dialect.Quote(f)+" IN ("+subquery+")")
|
||||
}
|
||||
|
||||
func (q *selectQuery) addWhereIn(f string, vals []string) {
|
||||
count := len(vals)
|
||||
if count > 1 {
|
||||
sb := strings.Builder{}
|
||||
sb.WriteString(f)
|
||||
sb.WriteString(q.dialect.Quote(f))
|
||||
sb.WriteString(" IN (")
|
||||
for i := 0; i < count; i++ {
|
||||
if i > 0 {
|
||||
|
|
@ -46,7 +57,11 @@ func (q *selectQuery) toQuery() (string, []any) {
|
|||
args := q.args
|
||||
sb := strings.Builder{}
|
||||
sb.WriteString("SELECT ")
|
||||
sb.WriteString(strings.Join(q.fields, ","))
|
||||
quotedFields := make([]string, len(q.fields))
|
||||
for i, f := range q.fields {
|
||||
quotedFields[i] = q.dialect.Quote(f)
|
||||
}
|
||||
sb.WriteString(strings.Join(quotedFields, ","))
|
||||
sb.WriteString(" FROM ")
|
||||
sb.WriteString(q.from)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue