diff --git a/cmd/rdma/main.go b/cmd/rdma/main.go index e0fcd30..84fc932 100644 --- a/cmd/rdma/main.go +++ b/cmd/rdma/main.go @@ -3,12 +3,11 @@ package main import ( "encoding/json" "fmt" - "os" - "github.com/Mellanox/rdma-cni/pkg/cache" "github.com/Mellanox/rdma-cni/pkg/rdma" rdmatypes "github.com/Mellanox/rdma-cni/pkg/types" "github.com/Mellanox/rdma-cni/pkg/utils" + "os" "github.com/containernetworking/cni/pkg/skel" "github.com/containernetworking/cni/pkg/types" @@ -103,9 +102,11 @@ func (plugin *rdmaCniPlugin) parseConf(data []byte, envArgs string) (*rdmatypes. return &conf, nil } -// Move RDMA device to namespace -func (plugin *rdmaCniPlugin) moveRdmaDevToNs(rdmaDev string, nsPath string) error { - log.Debug().Msgf("moving RDMA device %s to namespace %s", rdmaDev, nsPath) +// Move RDMA device, sRdmaDev, to namespace and rename RDMA device to cRdmadev +func (plugin *rdmaCniPlugin) moveRdmaDevToNs(sRdmaDev string, cRdmaDev string, nsPath string) error { + log.Debug().Msgf("Moving RDMA device %s to namespace %s", sRdmaDev, nsPath) + var err error + renameReq := sRdmaDev != cRdmaDev targetNs, err := plugin.nsManager.GetNS(nsPath) if err != nil { @@ -113,16 +114,51 @@ func (plugin *rdmaCniPlugin) moveRdmaDevToNs(rdmaDev string, nsPath string) erro } defer targetNs.Close() - err = plugin.rdmaManager.MoveRdmaDevToNs(rdmaDev, targetNs) + tmpName := sRdmaDev + if renameReq { + // set temp name for RDMA dev + tmpName, err = plugin.rdmaManager.SetRdmaDevTempName(sRdmaDev) + if err != nil { + return err + } + defer func() { + if err != nil { + log.Warn().Msgf("Error occured while moving RDMA device to namespace. %v", err) + restoreErr := plugin.rdmaManager.SetRdmaDevName(tmpName, sRdmaDev) + if restoreErr != nil { + log.Warn().Msgf("Failed to restore RDMA device name. %v", restoreErr) + } + } + }() + } + + err = plugin.rdmaManager.MoveRdmaDevToNs(tmpName, targetNs) if err != nil { - return fmt.Errorf("failed to move RDMA device %s to namespace. %v", rdmaDev, err) + return fmt.Errorf("failed to move RDMA device %s to namespace. %v", tmpName, err) } - return nil + + if renameReq { + // rename RDMA dev in container NS to target name + err = targetNs.Do(func(hostNs ns.NetNS) error { + renameErr := plugin.rdmaManager.SetRdmaDevName(tmpName, cRdmaDev) + if renameErr != nil { + // move RDMA device back to host namespace. + restoreErr := plugin.rdmaManager.MoveRdmaDevToNs(tmpName, hostNs) + if restoreErr != nil { + log.Warn().Msgf("Failed to move RDMA device to default namespace after error. %v", restoreErr) + } + } + return renameErr + }) + } + return err } -// Move RDMA device from namespace to current (default) namespace -func (plugin *rdmaCniPlugin) moveRdmaDevFromNs(rdmaDev string, nsPath string) error { - log.Debug().Msgf("INFO: moving RDMA device %s from namespace %s to default namespace", rdmaDev, nsPath) +// Move RDMA device, cRdmaDev, from namespace to current (default) namespace and rename it to sRdmaDev +func (plugin *rdmaCniPlugin) moveRdmaDevFromNs(cRdmaDev string, sRdmaDev string, nsPath string) error { + log.Debug().Msgf("Moving RDMA device %s from namespace %s to default namespace", cRdmaDev, nsPath) + var err error + renameReq := cRdmaDev != sRdmaDev sourceNs, err := plugin.nsManager.GetNS(nsPath) if err != nil { @@ -136,16 +172,65 @@ func (plugin *rdmaCniPlugin) moveRdmaDevFromNs(rdmaDev string, nsPath string) er } defer targetNs.Close() + var tmpName string err = sourceNs.Do(func(_ ns.NetNS) error { - // Move RDMA device to default namespace - return plugin.rdmaManager.MoveRdmaDevToNs(rdmaDev, targetNs) + if renameReq { + // Move RDMA device to default namespace + var sourceNsError error + tmpName, sourceNsError = plugin.rdmaManager.SetRdmaDevTempName(cRdmaDev) + if sourceNsError != nil { + log.Warn().Msgf("Failed to restore RDMA device name to its original value. %v", sourceNsError) + return plugin.rdmaManager.MoveRdmaDevToNs(cRdmaDev, targetNs) + } + } + return plugin.rdmaManager.MoveRdmaDevToNs(tmpName, targetNs) }) if err != nil { - return fmt.Errorf("failed to move RDMA device %s to default namespace. %v", rdmaDev, err) + return fmt.Errorf("failed to move RDMA device %s to default namespace. %v", cRdmaDev, err) + } + if renameReq { + // set target name + err = targetNs.Do(func(_ ns.NetNS) error { + return plugin.rdmaManager.SetRdmaDevName(tmpName, sRdmaDev) + }) } return err } +func (plugin *rdmaCniPlugin) getContainerRdmaDeviceName(sRdmaDev string, nsPath string) string { + var err error + var sourceNs ns.NetNS + sourceNs, err = plugin.nsManager.GetNS(nsPath) + defer sourceNs.Close() + defer func() { + if err != nil { + log.Warn().Msgf("Failed to generate container RDMA device name, %s. Original name will be used.", err) + } + }() + + var cRdmaDevs []string + err = sourceNs.Do(func(_ ns.NetNS) error { + var cErr error + cRdmaDevs, cErr = plugin.rdmaManager.GetRdmaDeviceList() + return cErr + }) + if err != nil { + return sRdmaDev + } + + var prefix string + prefix, err = utils.GetRdmaDevicePrefix(sRdmaDev) + if err != nil { + return sRdmaDev + } + + cNextRdmaDev, err := utils.GetNextRdmaDeviceName(prefix, cRdmaDevs) + if err != nil { + return sRdmaDev + } + return cNextRdmaDev +} + func (plugin *rdmaCniPlugin) CmdAdd(args *skel.CmdArgs) error { log.Info().Msgf("RDMA-CNI: cmdAdd") conf, err := plugin.parseConf(args.StdinData, args.Args) @@ -197,25 +282,26 @@ func (plugin *rdmaCniPlugin) CmdAdd(args *skel.CmdArgs) error { } // Move RDMA device to container namespace - rdmaDev := rdmaDevs[0] + sRdmaDev := rdmaDevs[0] + cRdmaDev := plugin.getContainerRdmaDeviceName(sRdmaDev, args.Netns) + log.Debug().Msgf("Sandbox RDMA device: %s, Container RDMA device: %s", sRdmaDev, cRdmaDev) - err = plugin.moveRdmaDevToNs(rdmaDev, args.Netns) + err = plugin.moveRdmaDevToNs(sRdmaDev, cRdmaDev, args.Netns) if err != nil { - return fmt.Errorf("failed to move RDMA device %s to namespace. %v", rdmaDev, err) + return fmt.Errorf("failed to move RDMA device %s to namespace. %v", sRdmaDev, err) } - // Save RDMA state state := rdmatypes.NewRdmaNetState() state.DeviceID = conf.DeviceID - state.SandboxRdmaDevName = rdmaDev - state.ContainerRdmaDevName = rdmaDev + state.SandboxRdmaDevName = sRdmaDev + state.ContainerRdmaDevName = cRdmaDev pRef := plugin.stateCache.GetStateRef(conf.Name, args.ContainerID, args.IfName) err = plugin.stateCache.Save(pRef, &state) if err != nil { // Move RDMA dev back to current namespace - restoreErr := plugin.moveRdmaDevFromNs(state.ContainerRdmaDevName, args.Netns) + restoreErr := plugin.moveRdmaDevFromNs(state.ContainerRdmaDevName, state.SandboxRdmaDevName, args.Netns) if restoreErr != nil { - return fmt.Errorf("save to cache failed %v, failed while restoring namespace for RDMA device %s. %v", err, rdmaDev, restoreErr) + return fmt.Errorf("save to cache failed %v, failed while restoring namespace for RDMA device %s. %v", err, sRdmaDev, restoreErr) } return err } @@ -249,11 +335,12 @@ func (plugin *rdmaCniPlugin) CmdDel(args *skel.CmdArgs) error { pRef := plugin.stateCache.GetStateRef(conf.Name, args.ContainerID, args.IfName) err = plugin.stateCache.Load(pRef, &rdmaState) if err != nil { - return err + log.Warn().Msgf("Failed to load state from cache, %v. preceding cmdAdd operation may have failed early.", err) + return nil } // Move RDMA device to default namespace - err = plugin.moveRdmaDevFromNs(rdmaState.ContainerRdmaDevName, args.Netns) + err = plugin.moveRdmaDevFromNs(rdmaState.ContainerRdmaDevName, rdmaState.SandboxRdmaDevName, args.Netns) if err != nil { return fmt.Errorf( "failed to restore RDMA device %s to default namespace. %v", rdmaState.ContainerRdmaDevName, err) diff --git a/pkg/rdma/rdma.go b/pkg/rdma/rdma.go index e828fa8..234b42e 100644 --- a/pkg/rdma/rdma.go +++ b/pkg/rdma/rdma.go @@ -2,6 +2,7 @@ package rdma import ( "fmt" + "syscall" "github.com/containernetworking/plugins/pkg/ns" ) @@ -18,12 +19,18 @@ func NewRdmaManager() RdmaManager { type RdmaManager interface { // Move RDMA device from current network namespace to network namespace MoveRdmaDevToNs(rdmaDev string, netNs ns.NetNS) error + // Get RDMA devices in the current network namespace + GetRdmaDeviceList() ([]string, error) // Get RDMA devices associated with the given PCI device in D:B:D.f format e.g 0000:04:00.0 GetRdmaDevsForPciDev(pciDev string) ([]string, error) // Get RDMA subsystem namespace awareness mode ["exclusive" | "shared"] GetSystemRdmaMode() (string, error) // Set RDMA subsystem namespace awareness mode ["exclusive" | "shared"] SetSystemRdmaMode(mode string) error + // Change RDMA device name + SetRdmaDevName(oldName string, newName string) error + // Set RDMA device temporary name + SetRdmaDevTempName(rdmaDev string) (string, error) } type rdmaManagerNetlink struct { @@ -57,3 +64,39 @@ func (rmn *rdmaManagerNetlink) GetSystemRdmaMode() (string, error) { func (rmn *rdmaManagerNetlink) SetSystemRdmaMode(mode string) error { return rmn.rdmaOps.RdmaSystemSetNetnsMode(mode) } + +// Change RDMA device name +func (rmn *rdmaManagerNetlink) SetRdmaDevName(oldName string, newName string) error { + rdmaLink, err := rmn.rdmaOps.RdmaLinkByName(oldName) + if err != nil { + return fmt.Errorf("cannot find RDMA link from name: %s", oldName) + } + err = rmn.rdmaOps.RdmaLinkSetName(rdmaLink, newName) + if err != nil { + return fmt.Errorf("failed to change RDMA device name from %s to %s. %v", oldName, newName, err) + } + return nil +} + +// Get RDMA devices in the current network namespace +func (rmn *rdmaManagerNetlink) GetRdmaDeviceList() ([]string, error) { + links, err := rmn.rdmaOps.GetRdmaLinkList() + if err != nil { + return nil, err + } + names := make([]string, len(links)) + for _, link := range links { + names = append(names, link.Attrs.Name) + } + return names, nil +} + +// Set RDMA device to a unique temporary name +func (rmn *rdmaManagerNetlink) SetRdmaDevTempName(rdmaDev string) (string, error) { + link, err := rmn.rdmaOps.RdmaLinkByName(rdmaDev) + if err != nil { + return "", err + } + tmpName := fmt.Sprintf("rdmadev_%d", link.Attrs.Index)[:(syscall.IFNAMSIZ - 1)] + return tmpName, rmn.rdmaOps.RdmaLinkSetName(link, tmpName) +} diff --git a/pkg/rdma/rdma_ops.go b/pkg/rdma/rdma_ops.go index eb799de..8fed360 100644 --- a/pkg/rdma/rdma_ops.go +++ b/pkg/rdma/rdma_ops.go @@ -11,12 +11,16 @@ type RdmaBasicOps interface { RdmaLinkByName(name string) (*netlink.RdmaLink, error) // Equivalent to netlink.RdmaLinkSetNsFd(...) RdmaLinkSetNsFd(link *netlink.RdmaLink, fd uint32) error + // Equivalent to netlink.RdmaLinkSetName(...) + RdmaLinkSetName(link *netlink.RdmaLink, name string) error // Equivalent to netlink.RdmaSystemGetNetnsMode(...) RdmaSystemGetNetnsMode() (string, error) // Equivalent to netlink.RdmaSystemSetNetnsMode(...) RdmaSystemSetNetnsMode(newMode string) error // Equivalent to rdmamap.GetRdmaDevicesForPcidev(...) GetRdmaDevicesForPcidev(pcidevName string) []string + // Equivalent to netlink.RdmaLinkList() + GetRdmaLinkList() ([]*netlink.RdmaLink, error) } func newRdmaBasicOps() RdmaBasicOps { @@ -36,6 +40,11 @@ func (rdma *rdmaBasicOpsImpl) RdmaLinkSetNsFd(link *netlink.RdmaLink, fd uint32) return netlink.RdmaLinkSetNsFd(link, fd) } +// Equivalent to netlink.RdmaLinkSetName(...) +func (rdma *rdmaBasicOpsImpl) RdmaLinkSetName(link *netlink.RdmaLink, name string) error { + return netlink.RdmaLinkSetName(link, name) +} + // Equivalent to netlink.RdmaSystemGetNetnsMode(...) func (rdma *rdmaBasicOpsImpl) RdmaSystemGetNetnsMode() (string, error) { return netlink.RdmaSystemGetNetnsMode() @@ -50,3 +59,8 @@ func (rdma *rdmaBasicOpsImpl) RdmaSystemSetNetnsMode(newMode string) error { func (rdma *rdmaBasicOpsImpl) GetRdmaDevicesForPcidev(pcidevName string) []string { return rdmamap.GetRdmaDevicesForPcidev(pcidevName) } + +// Equivalent to netlink.RdmaLinkList() +func (rdma *rdmaBasicOpsImpl) GetRdmaLinkList() ([]*netlink.RdmaLink, error) { + return netlink.RdmaLinkList() +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index c6427c8..c55350f 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -4,6 +4,8 @@ import ( "fmt" "path" "path/filepath" + "strconv" + "strings" "github.com/vishvananda/netlink" ) @@ -45,3 +47,51 @@ func GetVfPciDevFromMAC(mac string) (string, error) { } return dev, err } + +// Get RDMA device prefix and index. e.g for mlx5_3: prefix is mlx5 and index is 3 +// Note: the index is not related to the kernel RDMA device index +func getRdmaDevNamePrefixIndex(rdmaDev string) (prefix string, idx uint64, err error) { + s := strings.Split(rdmaDev, `_`) + if len(s) != 2 { + return "", 0, fmt.Errorf("unexpeded RDMA device format: %s", rdmaDev) + } + prefix = s[0] + idx, err = strconv.ParseUint(s[1], 0, 32) + if err != nil { + err = fmt.Errorf("failed to parse RDMA device index: %s, %v", rdmaDev, err) + } + return prefix, idx, err +} + +func getRdmaDevIndexFromName(rdmaDev string) (uint64, error) { + _, idx, err := getRdmaDevNamePrefixIndex(rdmaDev) + return idx, err +} + +// Get the next RDMA device name for a given RDMA device prefix +func GetNextRdmaDeviceName(prefix string, currDevs []string) (string, error) { + var nextDevIdx uint64 + nextDevIdx = 0 + if len(currDevs) != 0 { + for _, dev := range currDevs { + if !strings.HasPrefix(dev, prefix) { + continue + } + // extract index + idx, err := getRdmaDevIndexFromName(dev) + if err != nil { + return "", err + } + if idx > nextDevIdx { + nextDevIdx = idx + 1 + } + } + } + return fmt.Sprintf("%s_%d", prefix, nextDevIdx), nil +} + +// Get RDMA device driver prefix. e.g for mlx5_3 the prefix would be mlx5 +func GetRdmaDevicePrefix(rdmaDev string) (string, error) { + prefix, _, err := getRdmaDevNamePrefixIndex(rdmaDev) + return prefix, err +}