diff --git a/common/core/network.go b/common/core/network.go index ce560450..0c397e28 100644 --- a/common/core/network.go +++ b/common/core/network.go @@ -219,10 +219,22 @@ func (plugin *netPlugin) Add(args *cniSkel.CmdArgs) (resultError error) { // Apply the Network Policy for Endpoint epInfo.Policies = append(epInfo.Policies, networkInfo.Policies...) - // If LoopbackDSR is set, add to policies + hnsIPAM := false + // If LoopbackDSR is set and IP is already allocated via IPAM, add to policies before endpoint creation if cniConfig.OptionalFlags.LoopbackDSR { - hcnLoopbackRoute, _ := network.GetLoopbackDSRPolicy(&epInfo.IPAddress) - epInfo.Policies = append(epInfo.Policies, hcnLoopbackRoute) + if len(epInfo.IPAddress) > 0 { + hcnLoopbackRoute, err := network.GetLoopbackDSRPolicy(epInfo.IPAddress) + if err != nil { + logrus.Errorf("[cni-net] Failed to create loopbackDSR policy: %v, IPAddress: %v", err, epInfo.IPAddress) + return err + } + logrus.Debugf("[cni-net] Created loopbackDSR policy for IP: %v", epInfo.IPAddress) + epInfo.Policies = append(epInfo.Policies, hcnLoopbackRoute) + } else { + // IP will be assigned by HCN, add the loopbackDSR policy after endpoint creation + hnsIPAM = true + logrus.Debugf("[cni-net] IP not yet assigned, add loopbackDSR policy after endpoint creation") + } } epInfo, err = plugin.nm.CreateEndpoint(nwConfig.ID, epInfo, args.Netns) @@ -231,6 +243,28 @@ func (plugin *netPlugin) Add(args *cniSkel.CmdArgs) (resultError error) { return err } + // If LoopbackDSR is set but IP wasn't allocated via IPAM, add the policy after endpoint creation + if hnsIPAM { + if len(epInfo.IPAddress) == 0 { + logrus.Errorf("[cni-net] LoopbackDSR is enabled but endpoint IP address is not set after endpoint creation") + return errors.New("loopbackDSR requires IP address to be allocated") + } + + logrus.Debugf("[cni-net] Adding loopbackDSR policy for IP: %v after endpoint creation", epInfo.IPAddress) + hcnLoopbackRoute, err := network.GetLoopbackDSRPolicy(epInfo.IPAddress) + if err != nil { + logrus.Errorf("[cni-net] Failed to create loopbackDSR policy after endpoint creation: %v, IPAddress: %v", err, epInfo.IPAddress) + return err + } + + err = plugin.nm.ApplyPolicy(epInfo.ID, hcnLoopbackRoute) + if err != nil { + logrus.Errorf("[cni-net] Failed to apply loopbackDSR policy to endpoint: %v", err) + return err + } + logrus.Debugf("[cni-net] Successfully applied loopbackDSR policy to endpoint") + } + // Convert result to the requested CNI version. res := cni.GetCurrResult(nwConfig, epInfo, args.IfName, cniConfig) result, err := res.GetAsVersion(cniConfig.CniVersion) diff --git a/network/manager.go b/network/manager.go index d9e1d6a7..73fc5c20 100644 --- a/network/manager.go +++ b/network/manager.go @@ -4,6 +4,7 @@ package network import ( + "encoding/json" "fmt" "github.com/Microsoft/hcsshim/hcn" "github.com/Microsoft/windows-container-networking/common" @@ -33,6 +34,7 @@ type Manager interface { DeleteEndpoint(endpointID string) error GetEndpoint(endpointID string, withIpv6 bool) (*EndpointInfo, error) GetEndpointByName(endpointName string, withIpv6 bool) (*EndpointInfo, error) + ApplyPolicy(endpointID string, policy Policy) error } // NewManager creates a new networkManager. @@ -210,3 +212,31 @@ func (nm *networkManager) GetEndpointByName(endpointName string, withIpv6 bool) return GetEndpointInfoFromHostComputeEndpoint(hcnEndpoint, withIpv6), nil } + +// ApplyPolicy applies a policy to an existing endpoint. +func (nm *networkManager) ApplyPolicy(endpointID string, policy Policy) error { + nm.Lock() + defer nm.Unlock() + + hcnEndpoint, err := hcn.GetEndpointByID(endpointID) + if err != nil { + return fmt.Errorf("failed to get endpoint %s: %v", endpointID, err) + } + + var endpointPolicy hcn.EndpointPolicy + err = json.Unmarshal(policy.Data, &endpointPolicy) + if err != nil { + return fmt.Errorf("failed to unmarshal policy: %v", err) + } + + policyRequest := hcn.PolicyEndpointRequest{ + Policies: []hcn.EndpointPolicy{endpointPolicy}, + } + + err = hcnEndpoint.ApplyPolicy(hcn.RequestTypeAdd, policyRequest) + if err != nil { + return fmt.Errorf("failed to apply policy to endpoint: %v", err) + } + + return nil +} diff --git a/network/policy.go b/network/policy.go index 2212106a..45ccfd04 100644 --- a/network/policy.go +++ b/network/policy.go @@ -6,6 +6,7 @@ package network import ( "encoding/json" "errors" + "fmt" "net" "strconv" "strings" @@ -81,7 +82,10 @@ func GetPortMappingPolicy(externalPort int, internalPort int, protocol string, h } // GetLoopbackDSRPolicy creates a policy to support loopback direct server return. -func GetLoopbackDSRPolicy(ip *net.IP) (Policy, error) { +func GetLoopbackDSRPolicy(ip net.IP) (Policy, error) { + if len(ip) == 0 { + return Policy{}, fmt.Errorf("IP address cannot be empty for loopbackDSR policy") + } hcnLoopbackRoute := hcn.OutboundNatPolicySetting{ Destinations: []string{ip.String()}, }