Commit 5742704c authored by Thor Bjorgvinsson's avatar Thor Bjorgvinsson Committed by Daniel Robinson
Browse files

Once agent worker connects to termination channel, it should no longer listen for OS signals

cr: https://code.amazon.com/reviews/CR-54798789
parent a4ce79f1
...@@ -143,6 +143,9 @@ func startAgent(ssmAgent agent.ISSMAgent, context context.T) (err error) { ...@@ -143,6 +143,9 @@ func startAgent(ssmAgent agent.ISSMAgent, context context.T) (err error) {
} }
func blockUntilSignaled(log logger.T) { func blockUntilSignaled(log logger.T) {
// ssm-agent-worker will rely on the termination channel to be notified when to terminate
// the agent worker will listen to OS signals until termination channel is successfully connected
// Below channel will handle all machine initiated shutdown/reboot requests. // Below channel will handle all machine initiated shutdown/reboot requests.
// Set up channel on which to receive signal notifications. // Set up channel on which to receive signal notifications.
...@@ -155,12 +158,21 @@ func blockUntilSignaled(log logger.T) { ...@@ -155,12 +158,21 @@ func blockUntilSignaled(log logger.T) {
// Otherwise we will continue execution and exit the program. // Otherwise we will continue execution and exit the program.
signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGTERM) signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGTERM)
log.Debug("Listening to os signal until core agent termination channel is connected")
select { select {
case s := <-c: case s := <-c:
log.Info("ssm-agent-worker got signal:", s, " value:", s.Signal) log.Info("ssm-agent-worker got signal:", s, " value:", s.Signal)
case <-messageBusClient.RebootRequestChannel(): return
log.Info("Received core agent reboot signal") case <-messageBusClient.GetTerminationChannelConnectedChan():
log.Debug("ssm-agent-worker is connected to core agent termination channel, stopping OS signal listener")
} }
// Clean up OS signal listener
signal.Stop(c)
close(c)
log.Debug("Waiting for termination request from core agent")
<-messageBusClient.GetTerminationRequestChan()
} }
// Run as a single process. Used by Unix systems and when running agent from console. // Run as a single process. Used by Unix systems and when running agent from console.
......
...@@ -32,16 +32,18 @@ import ( ...@@ -32,16 +32,18 @@ import (
type IMessageBus interface { type IMessageBus interface {
ProcessHealthRequest() ProcessHealthRequest()
ProcessTerminationRequest() ProcessTerminationRequest()
RebootRequestChannel() chan bool GetTerminationRequestChan() chan bool
GetTerminationChannelConnectedChan() chan bool
} }
// MessageBus contains the ipc channel to communicate to core agent. // MessageBus contains the ipc channel to communicate to core agent.
// It contains a reboot request channel that agent listens to // It contains a reboot request channel that agent listens to
type MessageBus struct { type MessageBus struct {
context context.T context context.T
healthChannel channel.IChannel healthChannel channel.IChannel
terminationChannel channel.IChannel terminationChannel channel.IChannel
rebootRequest chan bool terminationRequestChannel chan bool
terminationChannelConnected chan bool
} }
// NewMessageBus creates a new instance of MessageBus // NewMessageBus creates a new instance of MessageBus
...@@ -50,10 +52,11 @@ func NewMessageBus(context context.T) *MessageBus { ...@@ -50,10 +52,11 @@ func NewMessageBus(context context.T) *MessageBus {
identity := context.Identity() identity := context.Identity()
channelCreator := channel.GetChannelCreator(log, context.AppConfig(), identity) channelCreator := channel.GetChannelCreator(log, context.AppConfig(), identity)
return &MessageBus{ return &MessageBus{
context: context, context: context,
healthChannel: channelCreator(log, identity), healthChannel: channelCreator(log, identity),
terminationChannel: channelCreator(log, identity), terminationChannel: channelCreator(log, identity),
rebootRequest: make(chan bool, 1), terminationRequestChannel: make(chan bool, 1),
terminationChannelConnected: make(chan bool, 1),
} }
} }
...@@ -74,7 +77,7 @@ func (bus *MessageBus) ProcessHealthRequest() { ...@@ -74,7 +77,7 @@ func (bus *MessageBus) ProcessHealthRequest() {
if bus.healthChannel.IsConnect() { if bus.healthChannel.IsConnect() {
if err = bus.healthChannel.Close(); err != nil { if err = bus.healthChannel.Close(); err != nil {
bus.context.Log().Errorf("failed to close ipc channel: %v", err) bus.context.Log().Errorf("failed to close health channel: %v", err)
} }
} }
}() }()
...@@ -84,7 +87,7 @@ func (bus *MessageBus) ProcessHealthRequest() { ...@@ -84,7 +87,7 @@ func (bus *MessageBus) ProcessHealthRequest() {
if err = bus.dialToCoreAgentChannel(message.GetWorkerHealthRequest, message.GetWorkerHealthChannel); err != nil { if err = bus.dialToCoreAgentChannel(message.GetWorkerHealthRequest, message.GetWorkerHealthChannel); err != nil {
// This happens when worker started before core agent is // This happens when worker started before core agent is
// In practise, it should never happen // In practise, it should never happen
log.Errorf("failed to listen to Core Agent broadcast channel: %s", err.Error()) log.Errorf("failed to listen to Core Agent health channel: %s", err.Error())
time.Sleep(time.Duration(bus.context.AppConfig().Ssm.HealthFrequencyMinutes) * time.Minute) time.Sleep(time.Duration(bus.context.AppConfig().Ssm.HealthFrequencyMinutes) * time.Minute)
} else { } else {
break break
...@@ -96,16 +99,16 @@ func (bus *MessageBus) ProcessHealthRequest() { ...@@ -96,16 +99,16 @@ func (bus *MessageBus) ProcessHealthRequest() {
for { for {
var request *message.Message var request *message.Message
if msg, err = bus.healthChannel.Recv(); err != nil { if msg, err = bus.healthChannel.Recv(); err != nil {
log.Errorf("cannot recv: %s", err.Error()) log.Errorf("Failed to receive from health channel: %s", err.Error())
continue continue
} }
log.Debugf("Received Core Agent health request %s", string(msg)) log.Debugf("Received health request from core agent %s", string(msg))
if err = json.Unmarshal(msg, &request); err != nil { if err = json.Unmarshal(msg, &request); err != nil {
log.Errorf("failed to unmarshal message: %s", err.Error()) log.Errorf("failed to unmarshal message: %s", err.Error())
continue continue
} }
log.Debugf("unmarshal health request: %v", request)
if request.Topic == message.GetWorkerHealthRequest { if request.Topic == message.GetWorkerHealthRequest {
var result *message.Message var result *message.Message
...@@ -122,7 +125,7 @@ func (bus *MessageBus) ProcessHealthRequest() { ...@@ -122,7 +125,7 @@ func (bus *MessageBus) ProcessHealthRequest() {
continue continue
} }
} else { } else {
log.Infof("Received invalid message, %s", request.Topic) log.Warnf("Received invalid message on health channel, %s", request.Topic)
} }
} }
} }
...@@ -143,7 +146,7 @@ func (bus *MessageBus) ProcessTerminationRequest() { ...@@ -143,7 +146,7 @@ func (bus *MessageBus) ProcessTerminationRequest() {
if bus.terminationChannel.IsConnect() { if bus.terminationChannel.IsConnect() {
if err = bus.terminationChannel.Close(); err != nil { if err = bus.terminationChannel.Close(); err != nil {
bus.context.Log().Errorf("failed to close ipc channel: %v", err) bus.context.Log().Errorf("failed to close termination channel: %v", err)
} }
} }
}() }()
...@@ -153,7 +156,7 @@ func (bus *MessageBus) ProcessTerminationRequest() { ...@@ -153,7 +156,7 @@ func (bus *MessageBus) ProcessTerminationRequest() {
if err = bus.dialToCoreAgentChannel(message.TerminateWorkerRequest, message.TerminationWorkerChannel); err != nil { if err = bus.dialToCoreAgentChannel(message.TerminateWorkerRequest, message.TerminationWorkerChannel); err != nil {
// This happens when worker started before core agent is // This happens when worker started before core agent is
// In practise, it should never happen // In practise, it should never happen
log.Errorf("failed to listen to Core Agent broadcast channel: %s", err.Error()) log.Errorf("failed to listen to termination channel: %s", err.Error())
time.Sleep(time.Duration(bus.context.AppConfig().Ssm.HealthFrequencyMinutes) * time.Minute) time.Sleep(time.Duration(bus.context.AppConfig().Ssm.HealthFrequencyMinutes) * time.Minute)
} else { } else {
break break
...@@ -162,21 +165,22 @@ func (bus *MessageBus) ProcessTerminationRequest() { ...@@ -162,21 +165,22 @@ func (bus *MessageBus) ProcessTerminationRequest() {
} }
log.Infof("Start to listen to Core Agent termination channel") log.Infof("Start to listen to Core Agent termination channel")
bus.terminationChannelConnected <- true
for { for {
var request *message.Message var request *message.Message
if msg, err = bus.terminationChannel.Recv(); err != nil { if msg, err = bus.terminationChannel.Recv(); err != nil {
log.Errorf("cannot recv: %s", err.Error()) log.Errorf("cannot recv: %s", err.Error())
continue continue
} }
log.Infof("Received Core Agent termination request %s", string(msg)) log.Infof("Received termination message from core agent %s", string(msg))
if err = json.Unmarshal(msg, &request); err != nil { if err = json.Unmarshal(msg, &request); err != nil {
log.Errorf("failed to unmarshal message: %s", err.Error()) log.Errorf("failed to unmarshal message: %s", err.Error())
continue continue
} }
log.Debugf("unmarshal health request: %v", request)
if request.Topic == message.TerminateWorkerRequest { if request.Topic == message.TerminateWorkerRequest {
log.Infof("Received Core Agent termination signal, terminating %s", appconfig.SSMAgentWorkerName) log.Debugf("Received termination signal from core agent, terminating %s", appconfig.SSMAgentWorkerName)
var result *message.Message var result *message.Message
if result, err = message.CreateTerminateWorkerResult( if result, err = message.CreateTerminateWorkerResult(
...@@ -184,7 +188,7 @@ func (bus *MessageBus) ProcessTerminationRequest() { ...@@ -184,7 +188,7 @@ func (bus *MessageBus) ProcessTerminationRequest() {
message.LongRunning, message.LongRunning,
os.Getpid(), os.Getpid(),
true); err != nil { true); err != nil {
log.Errorf("failed to create health message: %s", err.Error()) log.Errorf("failed to create termination response: %s", err.Error())
} }
if err = bus.terminationChannel.Send(result); err != nil { if err = bus.terminationChannel.Send(result); err != nil {
...@@ -193,19 +197,14 @@ func (bus *MessageBus) ProcessTerminationRequest() { ...@@ -193,19 +197,14 @@ func (bus *MessageBus) ProcessTerminationRequest() {
} }
// terminating ssm-agent-worker // terminating ssm-agent-worker
bus.rebootRequest <- true bus.terminationRequestChannel <- true
break break
} else { } else {
log.Infof("Received invalid message, %s", request.Topic) log.Warnf("Received invalid message on termination channel, %s", request.Topic)
} }
} }
} }
// RebootRequestChannel returns the reboot request channel
func (bus *MessageBus) RebootRequestChannel() chan bool {
return bus.rebootRequest
}
func (bus *MessageBus) dialToCoreAgentChannel(topic message.TopicType, address string) error { func (bus *MessageBus) dialToCoreAgentChannel(topic message.TopicType, address string) error {
var err error var err error
...@@ -234,3 +233,13 @@ func (bus *MessageBus) dialToCoreAgentChannel(topic message.TopicType, address s ...@@ -234,3 +233,13 @@ func (bus *MessageBus) dialToCoreAgentChannel(topic message.TopicType, address s
return fmt.Errorf("unknown topic type: %s", topic) return fmt.Errorf("unknown topic type: %s", topic)
} }
} }
// GetTerminationRequestChan returns the terminate request channel
func (bus *MessageBus) GetTerminationRequestChan() chan bool {
return bus.terminationRequestChannel
}
// GetTerminationChannelConnectedChan returns the channel notifying when termination channel is connected
func (bus *MessageBus) GetTerminationChannelConnectedChan() chan bool {
return bus.terminationChannelConnected
}
...@@ -54,10 +54,11 @@ func (suite *MessageBusTestSuite) SetupTest() { ...@@ -54,10 +54,11 @@ func (suite *MessageBusTestSuite) SetupTest() {
channels[message.TerminateWorkerRequest] = suite.mockTerminateChannel channels[message.TerminateWorkerRequest] = suite.mockTerminateChannel
suite.messageBus = &MessageBus{ suite.messageBus = &MessageBus{
context: suite.mockContext, context: suite.mockContext,
healthChannel: suite.mockHealthChannel, healthChannel: suite.mockHealthChannel,
terminationChannel: suite.mockTerminateChannel, terminationChannel: suite.mockTerminateChannel,
rebootRequest: make(chan bool, 1), terminationRequestChannel: make(chan bool, 1),
terminationChannelConnected: make(chan bool, 1),
} }
} }
...@@ -81,4 +82,8 @@ func (suite *MessageBusTestSuite) TestProcessTerminationRequest_Successful() { ...@@ -81,4 +82,8 @@ func (suite *MessageBusTestSuite) TestProcessTerminationRequest_Successful() {
suite.messageBus.ProcessTerminationRequest() suite.messageBus.ProcessTerminationRequest()
suite.mockTerminateChannel.AssertExpectations(suite.T()) suite.mockTerminateChannel.AssertExpectations(suite.T())
// Assert termination channel connected and that a termination message is sent
suite.Assertions.Equal(true, <-suite.messageBus.GetTerminationChannelConnectedChan())
suite.Assertions.Equal(true, <-suite.messageBus.GetTerminationRequestChan())
} }
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment