From 2a32f0bacd5ba3a3c5324e001556c62d785d9d02 Mon Sep 17 00:00:00 2001 From: Andreas Bergmeier Date: Fri, 7 Jan 2022 23:09:21 +0100 Subject: [PATCH] Allow processing of common options from FlagSet In situations where you don't want/need Cobra climbing behavior nor Cobra at all using FlagSet is the easier sell. Signed-off-by: Andreas Bergmeier --- pkg/parse/parse.go | 54 +++++++++++++++++++++++------------------ pkg/parse/parse_test.go | 15 ++++++++++++ 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/pkg/parse/parse.go b/pkg/parse/parse.go index dc90d441e..658ce1c60 100644 --- a/pkg/parse/parse.go +++ b/pkg/parse/parse.go @@ -28,6 +28,7 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "github.com/spf13/pflag" "golang.org/x/term" ) @@ -53,6 +54,11 @@ var ( // CommonBuildOptions parses the build options from the bud cli func CommonBuildOptions(c *cobra.Command) (*define.CommonBuildOptions, error) { + return CommonBuildOptionsFromFlagSet(c.Flags(), c.Flag) +} + +// CommonBuildOptionsFromFlagSet parses the build options from the bud cli +func CommonBuildOptionsFromFlagSet(flags *pflag.FlagSet, findFlagFunc func(name string) *pflag.Flag) (*define.CommonBuildOptions, error) { var ( memoryLimit int64 memorySwap int64 @@ -60,7 +66,7 @@ func CommonBuildOptions(c *cobra.Command) (*define.CommonBuildOptions, error) { err error ) - memVal, _ := c.Flags().GetString("memory") + memVal, _ := flags.GetString("memory") if memVal != "" { memoryLimit, err = units.RAMInBytes(memVal) if err != nil { @@ -68,7 +74,7 @@ func CommonBuildOptions(c *cobra.Command) (*define.CommonBuildOptions, error) { } } - memSwapValue, _ := c.Flags().GetString("memory-swap") + memSwapValue, _ := flags.GetString("memory-swap") if memSwapValue != "" { if memSwapValue == "-1" { memorySwap = -1 @@ -80,7 +86,7 @@ func CommonBuildOptions(c *cobra.Command) (*define.CommonBuildOptions, error) { } } - addHost, _ := c.Flags().GetStringSlice("add-host") + addHost, _ := flags.GetStringSlice("add-host") if len(addHost) > 0 { for _, host := range addHost { if err := validateExtraHost(host); err != nil { @@ -91,8 +97,8 @@ func CommonBuildOptions(c *cobra.Command) (*define.CommonBuildOptions, error) { noDNS = false dnsServers := []string{} - if c.Flag("dns").Changed { - dnsServers, _ = c.Flags().GetStringSlice("dns") + if flags.Changed("dns") { + dnsServers, _ = flags.GetStringSlice("dns") for _, server := range dnsServers { if strings.ToLower(server) == "none" { noDNS = true @@ -104,62 +110,62 @@ func CommonBuildOptions(c *cobra.Command) (*define.CommonBuildOptions, error) { } dnsSearch := []string{} - if c.Flag("dns-search").Changed { - dnsSearch, _ = c.Flags().GetStringSlice("dns-search") + if flags.Changed("dns-search") { + dnsSearch, _ = flags.GetStringSlice("dns-search") if noDNS && len(dnsSearch) > 0 { return nil, errors.Errorf("invalid --dns-search, --dns-search may not be used with --dns=none") } } dnsOptions := []string{} - if c.Flag("dns-option").Changed { - dnsOptions, _ = c.Flags().GetStringSlice("dns-option") + if flags.Changed("dns-option") { + dnsOptions, _ = flags.GetStringSlice("dns-option") if noDNS && len(dnsOptions) > 0 { return nil, errors.Errorf("invalid --dns-option, --dns-option may not be used with --dns=none") } } - if _, err := units.FromHumanSize(c.Flag("shm-size").Value.String()); err != nil { + if _, err := units.FromHumanSize(findFlagFunc("shm-size").Value.String()); err != nil { return nil, errors.Wrapf(err, "invalid --shm-size") } - volumes, _ := c.Flags().GetStringArray("volume") + volumes, _ := flags.GetStringArray("volume") if err := Volumes(volumes); err != nil { return nil, err } - cpuPeriod, _ := c.Flags().GetUint64("cpu-period") - cpuQuota, _ := c.Flags().GetInt64("cpu-quota") - cpuShares, _ := c.Flags().GetUint64("cpu-shares") - httpProxy, _ := c.Flags().GetBool("http-proxy") + cpuPeriod, _ := flags.GetUint64("cpu-period") + cpuQuota, _ := flags.GetInt64("cpu-quota") + cpuShares, _ := flags.GetUint64("cpu-shares") + httpProxy, _ := flags.GetBool("http-proxy") ulimit := []string{} - if c.Flag("ulimit").Changed { - ulimit, _ = c.Flags().GetStringSlice("ulimit") + if flags.Changed("ulimit") { + ulimit, _ = flags.GetStringSlice("ulimit") } - secrets, _ := c.Flags().GetStringArray("secret") - sshsources, _ := c.Flags().GetStringArray("ssh") + secrets, _ := flags.GetStringArray("secret") + sshsources, _ := flags.GetStringArray("ssh") commonOpts := &define.CommonBuildOptions{ AddHost: addHost, CPUPeriod: cpuPeriod, CPUQuota: cpuQuota, - CPUSetCPUs: c.Flag("cpuset-cpus").Value.String(), - CPUSetMems: c.Flag("cpuset-mems").Value.String(), + CPUSetCPUs: findFlagFunc("cpuset-cpus").Value.String(), + CPUSetMems: findFlagFunc("cpuset-mems").Value.String(), CPUShares: cpuShares, - CgroupParent: c.Flag("cgroup-parent").Value.String(), + CgroupParent: findFlagFunc("cgroup-parent").Value.String(), DNSOptions: dnsOptions, DNSSearch: dnsSearch, DNSServers: dnsServers, HTTPProxy: httpProxy, Memory: memoryLimit, MemorySwap: memorySwap, - ShmSize: c.Flag("shm-size").Value.String(), + ShmSize: findFlagFunc("shm-size").Value.String(), Ulimit: ulimit, Volumes: volumes, Secrets: secrets, SSHSources: sshsources, } - securityOpts, _ := c.Flags().GetStringArray("security-opt") + securityOpts, _ := flags.GetStringArray("security-opt") if err := parseSecurityOpts(securityOpts, commonOpts); err != nil { return nil, err } diff --git a/pkg/parse/parse_test.go b/pkg/parse/parse_test.go index 5e96c2acc..55b848f26 100644 --- a/pkg/parse/parse_test.go +++ b/pkg/parse/parse_test.go @@ -5,9 +5,24 @@ import ( "runtime" "testing" + "github.com/spf13/pflag" "github.com/stretchr/testify/assert" ) +func TestCommonBuildOptionsFromFlagSet(t *testing.T) { + fs := pflag.NewFlagSet("testme", pflag.PanicOnError) + fs.String("memory", "1GB", "") + fs.String("shm-size", "5TB", "") + fs.String("cpuset-cpus", "1", "") + fs.String("cpuset-mems", "2", "") + fs.String("cgroup-parent", "none", "") + err := fs.Parse([]string{"--memory", "2GB"}) + assert.NoError(t, err) + cbo, err := CommonBuildOptionsFromFlagSet(fs, fs.Lookup) + assert.NoError(t, err) + assert.Equal(t, cbo.Memory, int64(2147483648)) +} + // TestDeviceParser verifies the given device strings is parsed correctly func TestDeviceParser(t *testing.T) { if runtime.GOOS != "linux" {