mirror of https://github.com/grafana/grafana.git
				
				
				
			
		
			
				
	
	
		
			333 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Go
		
	
	
	
			
		
		
	
	
			333 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Go
		
	
	
	
| package postgres
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"path/filepath"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"testing"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/grafana/grafana-plugin-sdk-go/backend"
 | |
| 	"github.com/grafana/grafana-plugin-sdk-go/backend/log"
 | |
| 	"github.com/grafana/grafana/pkg/setting"
 | |
| 	"github.com/grafana/grafana/pkg/tsdb/grafana-postgresql-datasource/sqleng"
 | |
| 	"github.com/stretchr/testify/assert"
 | |
| 	"github.com/stretchr/testify/require"
 | |
| 
 | |
| 	_ "github.com/lib/pq"
 | |
| )
 | |
| 
 | |
// writeCertFileCallNum counts how many times the stub installed by
// mockWriteCertFile has been invoked. mockWriteCertFile and
// resetWriteCertFile both set it back to zero.
var writeCertFileCallNum int
 | |
| 
 | |
| // TestDataSourceCacheManager is to test the Cache manager
 | |
| func TestDataSourceCacheManager(t *testing.T) {
 | |
| 	cfg := setting.NewCfg()
 | |
| 	cfg.DataPath = t.TempDir()
 | |
| 	mng := tlsManager{
 | |
| 		logger:          backend.NewLoggerWith("logger", "tsdb.postgres"),
 | |
| 		dsCacheInstance: datasourceCacheManager{locker: newLocker()},
 | |
| 		dataPath:        cfg.DataPath,
 | |
| 	}
 | |
| 	jsonData := sqleng.JsonData{
 | |
| 		Mode:                "verify-full",
 | |
| 		ConfigurationMethod: "file-content",
 | |
| 	}
 | |
| 	secureJSONData := map[string]string{
 | |
| 		"tlsClientCert": "I am client certification",
 | |
| 		"tlsClientKey":  "I am client key",
 | |
| 		"tlsCACert":     "I am CA certification",
 | |
| 	}
 | |
| 
 | |
| 	updateTime := time.Now().Add(-5 * time.Minute)
 | |
| 
 | |
| 	mockValidateCertFilePaths()
 | |
| 	t.Cleanup(resetValidateCertFilePaths)
 | |
| 
 | |
| 	t.Run("Check datasource cache creation", func(t *testing.T) {
 | |
| 		var wg sync.WaitGroup
 | |
| 		wg.Add(10)
 | |
| 		for id := int64(1); id <= 10; id++ {
 | |
| 			go func(id int64) {
 | |
| 				ds := sqleng.DataSourceInfo{
 | |
| 					ID:                      id,
 | |
| 					Updated:                 updateTime,
 | |
| 					Database:                "database",
 | |
| 					JsonData:                jsonData,
 | |
| 					DecryptedSecureJSONData: secureJSONData,
 | |
| 					UID:                     "testData",
 | |
| 				}
 | |
| 				s := tlsSettings{}
 | |
| 				err := mng.writeCertFiles(ds, &s)
 | |
| 				require.NoError(t, err)
 | |
| 				wg.Done()
 | |
| 			}(id)
 | |
| 		}
 | |
| 		wg.Wait()
 | |
| 
 | |
| 		t.Run("check cache creation is succeed", func(t *testing.T) {
 | |
| 			for id := int64(1); id <= 10; id++ {
 | |
| 				updated, ok := mng.dsCacheInstance.cache.Load(strconv.Itoa(int(id)))
 | |
| 				require.True(t, ok)
 | |
| 				require.Equal(t, updateTime, updated)
 | |
| 			}
 | |
| 		})
 | |
| 	})
 | |
| 
 | |
| 	t.Run("Check datasource cache modification", func(t *testing.T) {
 | |
| 		t.Run("check when version not changed, cache and files are not updated", func(t *testing.T) {
 | |
| 			mockWriteCertFile()
 | |
| 			t.Cleanup(resetWriteCertFile)
 | |
| 			var wg1 sync.WaitGroup
 | |
| 			wg1.Add(5)
 | |
| 			for id := int64(1); id <= 5; id++ {
 | |
| 				go func(id int64) {
 | |
| 					ds := sqleng.DataSourceInfo{
 | |
| 						ID:                      1,
 | |
| 						Updated:                 updateTime,
 | |
| 						Database:                "database",
 | |
| 						JsonData:                jsonData,
 | |
| 						DecryptedSecureJSONData: secureJSONData,
 | |
| 						UID:                     "testData",
 | |
| 					}
 | |
| 					s := tlsSettings{}
 | |
| 					err := mng.writeCertFiles(ds, &s)
 | |
| 					require.NoError(t, err)
 | |
| 					wg1.Done()
 | |
| 				}(id)
 | |
| 			}
 | |
| 			wg1.Wait()
 | |
| 			assert.Equal(t, writeCertFileCallNum, 0)
 | |
| 		})
 | |
| 
 | |
| 		t.Run("cache is updated with the last datasource version", func(t *testing.T) {
 | |
| 			dsV2 := sqleng.DataSourceInfo{
 | |
| 				ID:                      1,
 | |
| 				Updated:                 updateTime.Add(time.Minute),
 | |
| 				Database:                "database",
 | |
| 				JsonData:                jsonData,
 | |
| 				DecryptedSecureJSONData: secureJSONData,
 | |
| 				UID:                     "testData",
 | |
| 			}
 | |
| 			dsV3 := sqleng.DataSourceInfo{
 | |
| 				ID:                      1,
 | |
| 				Updated:                 updateTime.Add(2 * time.Minute),
 | |
| 				Database:                "database",
 | |
| 				JsonData:                jsonData,
 | |
| 				DecryptedSecureJSONData: secureJSONData,
 | |
| 				UID:                     "testData",
 | |
| 			}
 | |
| 			s := tlsSettings{}
 | |
| 			err := mng.writeCertFiles(dsV2, &s)
 | |
| 			require.NoError(t, err)
 | |
| 			err = mng.writeCertFiles(dsV3, &s)
 | |
| 			require.NoError(t, err)
 | |
| 			version, ok := mng.dsCacheInstance.cache.Load("1")
 | |
| 			require.True(t, ok)
 | |
| 			require.Equal(t, updateTime.Add(2*time.Minute), version)
 | |
| 		})
 | |
| 	})
 | |
| }
 | |
| 
 | |
| // Test getFileName
 | |
| 
 | |
| func TestGetFileName(t *testing.T) {
 | |
| 	testCases := []struct {
 | |
| 		desc                  string
 | |
| 		datadir               string
 | |
| 		fileType              certFileType
 | |
| 		expErr                string
 | |
| 		expectedGeneratedPath string
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:                  "Get File Name for root certification",
 | |
| 			datadir:               ".",
 | |
| 			fileType:              rootCert,
 | |
| 			expectedGeneratedPath: "root.crt",
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                  "Get File Name for client certification",
 | |
| 			datadir:               ".",
 | |
| 			fileType:              clientCert,
 | |
| 			expectedGeneratedPath: "client.crt",
 | |
| 		},
 | |
| 		{
 | |
| 			desc:                  "Get File Name for client certification",
 | |
| 			datadir:               ".",
 | |
| 			fileType:              clientKey,
 | |
| 			expectedGeneratedPath: "client.key",
 | |
| 		},
 | |
| 	}
 | |
| 	for _, tt := range testCases {
 | |
| 		t.Run(tt.desc, func(t *testing.T) {
 | |
| 			generatedPath := getFileName(tt.datadir, tt.fileType)
 | |
| 			assert.Equal(t, tt.expectedGeneratedPath, generatedPath)
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Test getTLSSettings.
 | |
| func TestGetTLSSettings(t *testing.T) {
 | |
| 	cfg := setting.NewCfg()
 | |
| 	cfg.DataPath = t.TempDir()
 | |
| 
 | |
| 	mockValidateCertFilePaths()
 | |
| 	t.Cleanup(resetValidateCertFilePaths)
 | |
| 
 | |
| 	updatedTime := time.Now()
 | |
| 
 | |
| 	testCases := []struct {
 | |
| 		desc           string
 | |
| 		expErr         string
 | |
| 		jsonData       sqleng.JsonData
 | |
| 		secureJSONData map[string]string
 | |
| 		uid            string
 | |
| 		tlsSettings    tlsSettings
 | |
| 		updated        time.Time
 | |
| 	}{
 | |
| 		{
 | |
| 			desc:    "Custom TLS authentication disabled",
 | |
| 			updated: updatedTime,
 | |
| 			jsonData: sqleng.JsonData{
 | |
| 				Mode:                "disable",
 | |
| 				RootCertFile:        "i/am/coding/ca.crt",
 | |
| 				CertFile:            "i/am/coding/client.crt",
 | |
| 				CertKeyFile:         "i/am/coding/client.key",
 | |
| 				ConfigurationMethod: "file-path",
 | |
| 			},
 | |
| 			tlsSettings: tlsSettings{Mode: "disable"},
 | |
| 		},
 | |
| 		{
 | |
| 			desc:    "Custom TLS authentication with file path",
 | |
| 			updated: updatedTime.Add(time.Minute),
 | |
| 			jsonData: sqleng.JsonData{
 | |
| 				Mode:                "verify-full",
 | |
| 				ConfigurationMethod: "file-path",
 | |
| 				RootCertFile:        "i/am/coding/ca.crt",
 | |
| 				CertFile:            "i/am/coding/client.crt",
 | |
| 				CertKeyFile:         "i/am/coding/client.key",
 | |
| 			},
 | |
| 			tlsSettings: tlsSettings{
 | |
| 				Mode:                "verify-full",
 | |
| 				ConfigurationMethod: "file-path",
 | |
| 				RootCertFile:        "i/am/coding/ca.crt",
 | |
| 				CertFile:            "i/am/coding/client.crt",
 | |
| 				CertKeyFile:         "i/am/coding/client.key",
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			desc:    "Custom TLS mode verify-full with certificate files content",
 | |
| 			updated: updatedTime.Add(2 * time.Minute),
 | |
| 			uid:     "xxx",
 | |
| 			jsonData: sqleng.JsonData{
 | |
| 				Mode:                "verify-full",
 | |
| 				ConfigurationMethod: "file-content",
 | |
| 			},
 | |
| 			secureJSONData: map[string]string{
 | |
| 				"tlsCACert":     "I am CA certification",
 | |
| 				"tlsClientCert": "I am client certification",
 | |
| 				"tlsClientKey":  "I am client key",
 | |
| 			},
 | |
| 			tlsSettings: tlsSettings{
 | |
| 				Mode:                "verify-full",
 | |
| 				ConfigurationMethod: "file-content",
 | |
| 				RootCertFile:        filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"),
 | |
| 				CertFile:            filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"),
 | |
| 				CertKeyFile:         filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"),
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			desc:    "Custom TLS mode verify-ca with no client certificates with certificate files content",
 | |
| 			updated: updatedTime.Add(3 * time.Minute),
 | |
| 			uid:     "xxx",
 | |
| 			jsonData: sqleng.JsonData{
 | |
| 				Mode:                "verify-ca",
 | |
| 				ConfigurationMethod: "file-content",
 | |
| 			},
 | |
| 			secureJSONData: map[string]string{
 | |
| 				"tlsCACert": "I am CA certification",
 | |
| 			},
 | |
| 			tlsSettings: tlsSettings{
 | |
| 				Mode:                "verify-ca",
 | |
| 				ConfigurationMethod: "file-content",
 | |
| 				RootCertFile:        filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "root.crt"),
 | |
| 				CertFile:            "",
 | |
| 				CertKeyFile:         "",
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			desc:    "Custom TLS mode require with client certificates and no root certificate with certificate files content",
 | |
| 			updated: updatedTime.Add(4 * time.Minute),
 | |
| 			uid:     "xxx",
 | |
| 			jsonData: sqleng.JsonData{
 | |
| 				Mode:                "require",
 | |
| 				ConfigurationMethod: "file-content",
 | |
| 			},
 | |
| 			secureJSONData: map[string]string{
 | |
| 				"tlsClientCert": "I am client certification",
 | |
| 				"tlsClientKey":  "I am client key",
 | |
| 			},
 | |
| 			tlsSettings: tlsSettings{
 | |
| 				Mode:                "require",
 | |
| 				ConfigurationMethod: "file-content",
 | |
| 				RootCertFile:        "",
 | |
| 				CertFile:            filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.crt"),
 | |
| 				CertKeyFile:         filepath.Join(cfg.DataPath, "tls", "xxxgeneratedTLSCerts", "client.key"),
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 	for _, tt := range testCases {
 | |
| 		t.Run(tt.desc, func(t *testing.T) {
 | |
| 			var settings tlsSettings
 | |
| 			var err error
 | |
| 			mng := tlsManager{
 | |
| 				logger:          backend.NewLoggerWith("logger", "tsdb.postgres"),
 | |
| 				dsCacheInstance: datasourceCacheManager{locker: newLocker()},
 | |
| 				dataPath:        cfg.DataPath,
 | |
| 			}
 | |
| 
 | |
| 			ds := sqleng.DataSourceInfo{
 | |
| 				JsonData:                tt.jsonData,
 | |
| 				DecryptedSecureJSONData: tt.secureJSONData,
 | |
| 				UID:                     tt.uid,
 | |
| 				Updated:                 tt.updated,
 | |
| 			}
 | |
| 
 | |
| 			settings, err = mng.getTLSSettings(ds)
 | |
| 
 | |
| 			if tt.expErr == "" {
 | |
| 				require.NoError(t, err, tt.desc)
 | |
| 				assert.Equal(t, tt.tlsSettings, settings)
 | |
| 			} else {
 | |
| 				require.Error(t, err, tt.desc)
 | |
| 				assert.True(t, strings.HasPrefix(err.Error(), tt.expErr),
 | |
| 					fmt.Sprintf("%s: %q doesn't start with %q", tt.desc, err, tt.expErr))
 | |
| 			}
 | |
| 		})
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func mockValidateCertFilePaths() {
 | |
| 	validateCertFunc = func(rootCert, clientCert, clientKey string) error {
 | |
| 		return nil
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func resetValidateCertFilePaths() {
 | |
| 	validateCertFunc = validateCertFilePaths
 | |
| }
 | |
| 
 | |
| func mockWriteCertFile() {
 | |
| 	writeCertFileCallNum = 0
 | |
| 	writeCertFileFunc = func(logger log.Logger, fileContent string, generatedFilePath string) error {
 | |
| 		writeCertFileCallNum++
 | |
| 		return nil
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func resetWriteCertFile() {
 | |
| 	writeCertFileCallNum = 0
 | |
| 	writeCertFileFunc = writeCertFile
 | |
| }
 |