diff --git a/pkg/registry/apis/secret/contracts/encryption.go b/pkg/registry/apis/secret/contracts/encryption.go index 06613d23d3c..582ecc1b3a4 100644 --- a/pkg/registry/apis/secret/contracts/encryption.go +++ b/pkg/registry/apis/secret/contracts/encryption.go @@ -46,7 +46,7 @@ type GlobalEncryptedValueStorage interface { } type EncryptedValueMigrationExecutor interface { - Execute(ctx context.Context) error + Execute(ctx context.Context) (int, error) } type ConsolidationService interface { diff --git a/pkg/registry/apis/secret/encryption/manager/manager.go b/pkg/registry/apis/secret/encryption/manager/manager.go index 4c9c8e59c92..cf1f2173c96 100644 --- a/pkg/registry/apis/secret/encryption/manager/manager.go +++ b/pkg/registry/apis/secret/encryption/manager/manager.go @@ -21,10 +21,6 @@ import ( "github.com/grafana/grafana/pkg/util" ) -const ( - keyIdDelimiter = '#' -) - type EncryptionManager struct { tracer trace.Tracer store contracts.DataKeyStorage @@ -286,20 +282,6 @@ func (s *EncryptionManager) Decrypt(ctx context.Context, namespace string, paylo return nil, err } - // payload = payload[1:] - // endOfKey := bytes.Index(payload, []byte{keyIdDelimiter}) - // if endOfKey == -1 { - // err = fmt.Errorf("could not find valid key id in encrypted payload") - // return nil, err - // } - // b64Key := payload[:endOfKey] - // payload = payload[endOfKey+1:] - // keyId := make([]byte, base64.RawStdEncoding.DecodedLen(len(b64Key))) - // _, err = base64.RawStdEncoding.Decode(keyId, b64Key) - // if err != nil { - // return nil, err - // } - dataKey, err := s.dataKeyById(ctx, namespace, payload.DataKeyID) if err != nil { s.log.FromContext(ctx).Error("Failed to lookup data key by id", "id", payload.DataKeyID, "error", err) diff --git a/pkg/registry/apis/secret/secretkeeper/secretkeeper_test.go b/pkg/registry/apis/secret/secretkeeper/secretkeeper_test.go index 84baaf207b0..c73b80974e3 100644 --- a/pkg/registry/apis/secret/secretkeeper/secretkeeper_test.go +++ b/pkg/registry/apis/secret/secretkeeper/secretkeeper_test.go @@ -12,6 +12,7 @@ import ( osskmsproviders "github.com/grafana/grafana/pkg/registry/apis/secret/encryption/kmsproviders" "github.com/grafana/grafana/pkg/registry/apis/secret/encryption/manager" "github.com/grafana/grafana/pkg/registry/apis/secret/secretkeeper/sqlkeeper" + "github.com/grafana/grafana/pkg/registry/apis/secret/testutils" "github.com/grafana/grafana/pkg/services/sqlstore" "github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/storage/secret/database" @@ -65,7 +66,8 @@ func setupTestService(t *testing.T, cfg *setting.Cfg) (*OSSKeeperService, error) require.NoError(t, err) // Initialize the keeper service - keeperService, err := ProvideService(tracer, encValueStore, encryptionManager, nil) + keeperService, err := ProvideService(tracer, encValueStore, encryptionManager, &testutils.NoopMigrationExecutor{}, nil) + require.NoError(t, err) return keeperService, err } diff --git a/pkg/registry/apis/secret/secretkeeper/sqlkeeper/keeper.go b/pkg/registry/apis/secret/secretkeeper/sqlkeeper/keeper.go index cdeb686abe0..17d32b9cbc1 100644 --- a/pkg/registry/apis/secret/secretkeeper/sqlkeeper/keeper.go +++ b/pkg/registry/apis/secret/secretkeeper/sqlkeeper/keeper.go @@ -30,7 +30,7 @@ func NewSQLKeeper( reg prometheus.Registerer, ) (*SQLKeeper, error) { // Run the encrypted value store migration before anything else, otherwise operations may fail - err := migrationExecutor.Execute(context.Background()) + _, err := migrationExecutor.Execute(context.Background()) if err != nil { return nil, fmt.Errorf("failed to execute encrypted value store migration: %w", err) } diff --git a/pkg/registry/apis/secret/testutils/testutils.go b/pkg/registry/apis/secret/testutils/testutils.go index 66b5360e9db..2d1c94c2ea3 100644 --- a/pkg/registry/apis/secret/testutils/testutils.go +++ b/pkg/registry/apis/secret/testutils/testutils.go @@ -126,8 +126,13 @@ func Setup(t *testing.T, opts ...func(*SetupConfig)) Sut { globalEncryptedValueStorage, err := encryptionstorage.ProvideGlobalEncryptedValueStorage(database, tracer) require.NoError(t, err) - // TODO create a migration executor - sqlKeeper, err := sqlkeeper.NewSQLKeeper(tracer, encryptionManager, encryptedValueStorage, nil, nil) + // Initialize a noop migration executor for the sql keeper so it doesn't interfere with initialization + noopMigrationExecutor := &NoopMigrationExecutor{} + sqlKeeper, err := sqlkeeper.NewSQLKeeper(tracer, encryptionManager, encryptedValueStorage, noopMigrationExecutor, nil) + require.NoError(t, err) + + // Initialize a real migration executor for test + realMigrationExecutor, err := encryptionstorage.ProvideEncryptedValueMigrationExecutor(database, tracer, encryptedValueStorage, globalEncryptedValueStorage) require.NoError(t, err) var keeperService contracts.KeeperService = newKeeperServiceWrapper(sqlKeeper) @@ -160,39 +165,41 @@ func Setup(t *testing.T, opts ...func(*SetupConfig)) Sut { keeperService) return Sut{ - SecureValueService: secureValueService, - SecureValueMetadataStorage: secureValueMetadataStorage, - DecryptStorage: decryptStorage, - DecryptService: decryptService, - EncryptedValueStorage: encryptedValueStorage, - GlobalEncryptedValueStorage: globalEncryptedValueStorage, - SQLKeeper: sqlKeeper, - Database: database, - AccessClient: accessClient, - ConsolidationService: consolidationService, - EncryptionManager: encryptionManager, - GlobalDataKeyStore: globalDataKeyStore, - GarbageCollectionWorker: garbageCollectionWorker, - Clock: clock, - KeeperService: keeperService, - KeeperMetadataStorage: keeperMetadataStorage, + SecureValueService: secureValueService, + SecureValueMetadataStorage: secureValueMetadataStorage, + DecryptStorage: decryptStorage, + DecryptService: decryptService, + EncryptedValueStorage: encryptedValueStorage, + GlobalEncryptedValueStorage: globalEncryptedValueStorage, + EncryptedValueMigrationExecutor: realMigrationExecutor, + SQLKeeper: sqlKeeper, + Database: database, + AccessClient: accessClient, + ConsolidationService: consolidationService, + EncryptionManager: encryptionManager, + GlobalDataKeyStore: globalDataKeyStore, + GarbageCollectionWorker: garbageCollectionWorker, + Clock: clock, + KeeperService: keeperService, + KeeperMetadataStorage: keeperMetadataStorage, } } type Sut struct { - SecureValueService contracts.SecureValueService - SecureValueMetadataStorage contracts.SecureValueMetadataStorage - DecryptStorage contracts.DecryptStorage - DecryptService decryptcontracts.DecryptService - EncryptedValueStorage contracts.EncryptedValueStorage - GlobalEncryptedValueStorage contracts.GlobalEncryptedValueStorage - SQLKeeper *sqlkeeper.SQLKeeper - Database *database.Database - AccessClient types.AccessClient - ConsolidationService contracts.ConsolidationService - EncryptionManager contracts.EncryptionManager - GlobalDataKeyStore contracts.GlobalDataKeyStorage - GarbageCollectionWorker *garbagecollectionworker.Worker + SecureValueService contracts.SecureValueService + SecureValueMetadataStorage contracts.SecureValueMetadataStorage + DecryptStorage contracts.DecryptStorage + DecryptService decryptcontracts.DecryptService + EncryptedValueStorage contracts.EncryptedValueStorage + GlobalEncryptedValueStorage contracts.GlobalEncryptedValueStorage + EncryptedValueMigrationExecutor contracts.EncryptedValueMigrationExecutor + SQLKeeper *sqlkeeper.SQLKeeper + Database *database.Database + AccessClient types.AccessClient + ConsolidationService contracts.ConsolidationService + EncryptionManager contracts.EncryptionManager + GlobalDataKeyStore contracts.GlobalDataKeyStorage + GarbageCollectionWorker *garbagecollectionworker.Worker // The fake clock passed to implementations to make testing easier Clock *FakeClock KeeperService contracts.KeeperService @@ -368,3 +375,10 @@ func (c *FakeClock) Now() time.Time { func (c *FakeClock) AdvanceBy(duration time.Duration) { c.Current = c.Current.Add(duration) } + +type NoopMigrationExecutor struct { +} + +func (e *NoopMigrationExecutor) Execute(ctx context.Context) (int, error) { + return 0, nil +} diff --git a/pkg/storage/secret/encryption/encrypted_value_store.go b/pkg/storage/secret/encryption/encrypted_value_store.go index 80ac0b09403..4b19754221d 100644 --- a/pkg/storage/secret/encryption/encrypted_value_store.go +++ b/pkg/storage/secret/encryption/encrypted_value_store.go @@ -1,7 +1,9 @@ package encryption import ( + "bytes" "context" + "encoding/base64" "errors" "fmt" "time" @@ -343,26 +345,75 @@ func (s *globalEncryptedValStorage) CountAll(ctx context.Context, untilTime *int } type encryptedValMigrationExecutor struct { - db contracts.Database - dialect sqltemplate.Dialect - tracer trace.Tracer + db contracts.Database + dialect sqltemplate.Dialect + tracer trace.Tracer + encryptedValueStore contracts.EncryptedValueStorage + globalStore contracts.GlobalEncryptedValueStorage } func ProvideEncryptedValueMigrationExecutor( db contracts.Database, tracer trace.Tracer, + encryptedValueStore contracts.EncryptedValueStorage, + globalStore contracts.GlobalEncryptedValueStorage, ) (contracts.EncryptedValueMigrationExecutor, error) { return &encryptedValMigrationExecutor{ - db: db, - dialect: sqltemplate.DialectForDriver(db.DriverName()), - tracer: tracer, + db: db, + dialect: sqltemplate.DialectForDriver(db.DriverName()), + tracer: tracer, + encryptedValueStore: encryptedValueStore, + globalStore: globalStore, }, nil } -func (s *encryptedValMigrationExecutor) Execute(ctx context.Context) error { +func (s *encryptedValMigrationExecutor) Execute(ctx context.Context) (int, error) { ctx, span := s.tracer.Start(ctx, "EncryptedValueMigrationExecutor.Execute") defer span.End() - panic("not implemented") - return nil + // 1. Retrieve all encrypted values + encryptedValues, err := s.globalStore.ListAll(ctx, contracts.ListOpts{}, nil) + if err != nil { + return 0, fmt.Errorf("listing all encrypted values: %w", err) + } + + // This doesn't need to be done in a single transaction because there's no risk to successful rows if other rows fail + rowsAffected := 0 + for _, encryptedValue := range encryptedValues { + // 2. If the value already has the data key id broken out, skip it + if encryptedValue.DataKeyID != "" { + continue + } + + // 3. Split the data key id and the encrypted data out from the encoded payload + payload := encryptedValue.EncryptedData + const keyIdDelimiter = '#' + payload = payload[1:] + endOfKey := bytes.Index(payload, []byte{keyIdDelimiter}) + if endOfKey == -1 { + return 0, fmt.Errorf("could not find valid key id in encrypted payload with namespace %s and name %s and version %d", encryptedValue.Namespace, encryptedValue.Name, encryptedValue.Version) + } + b64Key := payload[:endOfKey] + encryptedData := payload[endOfKey+1:] + if len(encryptedData) == 0 { + return 0, fmt.Errorf("encrypted data is empty with namespace %s and name %s and version %d", encryptedValue.Namespace, encryptedValue.Name, encryptedValue.Version) + } + keyId := make([]byte, base64.RawStdEncoding.DecodedLen(len(b64Key))) + _, err := base64.RawStdEncoding.Decode(keyId, b64Key) + if err != nil { + return 0, fmt.Errorf("decoding key id with namespace %s and name %s and version %d: %w", encryptedValue.Namespace, encryptedValue.Name, encryptedValue.Version, err) + } + + // 4. Update the encrypted value with the data key id and the encrypted data + err = s.encryptedValueStore.Update(ctx, encryptedValue.Namespace, encryptedValue.Name, encryptedValue.Version, contracts.EncryptedPayload{ + DataKeyID: string(keyId), + EncryptedData: encryptedData, + }) + if err != nil { + return 0, fmt.Errorf("updating encrypted value with namespace %s and name %s and version %d: %w", encryptedValue.Namespace, encryptedValue.Name, encryptedValue.Version, err) + } + rowsAffected++ + } + + return rowsAffected, nil } diff --git a/pkg/storage/secret/encryption/encrypted_value_store_test.go b/pkg/storage/secret/encryption/encrypted_value_store_test.go index 62567795244..739575a9ebc 100644 --- a/pkg/storage/secret/encryption/encrypted_value_store_test.go +++ b/pkg/storage/secret/encryption/encrypted_value_store_test.go @@ -2,15 +2,23 @@ package encryption_test import ( "bytes" + "encoding/base64" "errors" + "fmt" "slices" "testing" + "text/template" "time" + "github.com/grafana/grafana/pkg/infra/usagestats" "github.com/grafana/grafana/pkg/registry/apis/secret/contracts" + "github.com/grafana/grafana/pkg/registry/apis/secret/encryption/cipher" + cipherService "github.com/grafana/grafana/pkg/registry/apis/secret/encryption/cipher/service" "github.com/grafana/grafana/pkg/registry/apis/secret/testutils" "github.com/grafana/grafana/pkg/storage/secret/encryption" + "github.com/grafana/grafana/pkg/storage/unified/sql/sqltemplate" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" "pgregory.net/rapid" ) @@ -233,6 +241,135 @@ func TestEncryptedValueStoreImpl(t *testing.T) { }) } +func TestEncryptedValueMigration(t *testing.T) { + t.Parallel() + + sut := testutils.Setup(t) + tracer := noop.NewTracerProvider().Tracer("test") + usageStats := &usagestats.UsageStatsMock{T: t} + enc, err := cipherService.ProvideAESGCMCipherService(tracer, usageStats) + require.NoError(t, err) + + testCases := []struct { + namespace string + name string + version int64 + plaintext string + dataKeyId string + }{ + { + namespace: "test-namespace-1", + name: "test-name-1", + version: 1, + plaintext: "test-plaintext-1", + dataKeyId: "test-data-key-id-1", + }, + { + namespace: "test-namespace-1", + name: "test-name-2", + version: 1, + plaintext: "test-plaintext-2", + dataKeyId: "test-data-key-id-1", + }, + { + namespace: "test-namespace-2", + name: "test-name-3", + version: 1, + plaintext: "test-plaintext-3", + dataKeyId: "test-data-key-id-2", + }, + } + + // Seed with data in the legacy format + for _, tc := range testCases { + err := createLegacyEncryptedData(t, sut, enc, tc.namespace, tc.name, tc.version, tc.plaintext, tc.dataKeyId) + require.NoError(t, err) + } + + // Run the migration and blindy trust it + rowsAffected, err := sut.EncryptedValueMigrationExecutor.Execute(t.Context()) + require.NoError(t, err) + require.Equal(t, len(testCases), rowsAffected) + + // Now validate that the data is in the new format + encryptedValues, err := sut.GlobalEncryptedValueStorage.ListAll(t.Context(), contracts.ListOpts{}, nil) + require.NoError(t, err) + require.Len(t, encryptedValues, 3) + + for _, tc := range testCases { + ev, err := sut.EncryptedValueStorage.Get(t.Context(), tc.namespace, tc.name, tc.version) + require.NoError(t, err) + + // Decrypt the encrypted data and check for equality + decrypted, err := enc.Decrypt(t.Context(), ev.EncryptedData, tc.dataKeyId) + require.NoError(t, err) + require.Equal(t, tc.dataKeyId, ev.DataKeyID) + require.Equal(t, tc.plaintext, string(decrypted)) + } +} + +// Helper function that bypasses interfaces and creates data in the legacy format directly in the database. +// The format is "#{encoded_key_id}#{encrypted_data}". +func createLegacyEncryptedData(t *testing.T, sut testutils.Sut, enc cipher.Cipher, namespace, name string, version int64, plaintext string, dataKeyId string) error { + t.Helper() + + encryptedData, err := enc.Encrypt(t.Context(), []byte(plaintext), dataKeyId) + require.NoError(t, err) + + // Encode using the legacy format + const keyIdDelimiter = '#' + prefix := make([]byte, base64.RawStdEncoding.EncodedLen(len(dataKeyId))+2) + base64.RawStdEncoding.Encode(prefix[1:], []byte(dataKeyId)) + prefix[0] = keyIdDelimiter + prefix[len(prefix)-1] = keyIdDelimiter + + blob := make([]byte, len(prefix)+len(encryptedData)) + copy(blob, prefix) + copy(blob[len(prefix):], encryptedData) + + createdTime := time.Now().Unix() + + encryptedValue := &encryption.EncryptedValue{ + Namespace: namespace, + Name: name, + Version: version, + EncryptedData: blob, + DataKeyID: "", + Created: createdTime, + Updated: createdTime, + } + + req := struct { + sqltemplate.SQLTemplate + Row *encryption.EncryptedValue + }{ + SQLTemplate: sqltemplate.New(sqltemplate.DialectForDriver(sut.Database.DriverName())), + Row: encryptedValue, + } + tmpl, err := template.ParseFiles("data/encrypted_value_create.sql") + if err != nil { + return fmt.Errorf("parsing template: %w", err) + } + + query, err := sqltemplate.Execute(tmpl, req) + if err != nil { + return fmt.Errorf("executing template: %w", err) + } + + res, err := sut.Database.ExecContext(t.Context(), query, req.GetArgs()...) + if err != nil { + return fmt.Errorf("inserting row: %w", err) + } + + if rowsAffected, err := res.RowsAffected(); err != nil { + return fmt.Errorf("getting rows affected: %w", err) + } else if rowsAffected != 1 { + return fmt.Errorf("expected 1 row affected, got %d", rowsAffected) + } + + return nil +} + func TestStateMachine(t *testing.T) { t.Parallel() @@ -345,7 +482,7 @@ func (m *model) create(namespace, name string, version int64, encryptedData []by if err != nil && !errors.Is(err, encryption.ErrEncryptedValueNotFound) { return nil, err } - // The entry being creted already exists + // The entry being created already exists if v != nil { return nil, encryption.ErrEncryptedValueAlreadyExists }