
One small gotcha in AWS Step Functions is that the waitForTaskToken integration pattern only really works within the same AWS account (and region).
You must pass task tokens from principals within the same AWS account. The tokens won’t work if you send them from principals in a different AWS account.
I’ve been prototyping a set of state machines that would need to run ECS tasks in different accounts. In practice, doing this is fairly simple: use an IAM role in the target account as the Credentials in the Task step:
"ExampleStep": {
  "Type": "Task",
  "Resource": "arn:aws:states:::ecs:runTask.waitForTaskToken",
  "Credentials": {
    "RoleArn": "arn:aws:iam::TARGET_ACCOUNT_ID:role/TARGET_ROLE_NAME"
  },
  // etc
}
As long as the state machine's IAM role has permission to assume the target role (sts:AssumeRole), and the target role has a trust relationship that allows the state machine's role to assume it, this all works.
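To make those two halves concrete, here is a minimal sketch that builds both policy documents as Python dicts, ready for json.dumps. The account IDs and role names are hypothetical placeholders; in a real setup these would live in Terraform alongside everything else.

```python
import json

# Hypothetical placeholders: substitute your own accounts and role names.
PARENT_ACCOUNT_ID = '111111111111'
TARGET_ACCOUNT_ID = '222222222222'
STATE_MACHINE_ROLE_ARN = f'arn:aws:iam::{PARENT_ACCOUNT_ID}:role/example-state-machine'
TARGET_ROLE_ARN = f'arn:aws:iam::{TARGET_ACCOUNT_ID}:role/example-target'


def state_machine_assume_policy(target_role_arn):
    """Identity policy on the state machine's role (parent account)."""
    return {
        'Version': '2012-10-17',
        'Statement': [{
            'Effect': 'Allow',
            'Action': 'sts:AssumeRole',
            'Resource': target_role_arn,
        }],
    }


def target_role_trust_policy(state_machine_role_arn):
    """Trust policy on the role in the target account."""
    return {
        'Version': '2012-10-17',
        'Statement': [{
            'Effect': 'Allow',
            'Principal': {'AWS': state_machine_role_arn},
            'Action': 'sts:AssumeRole',
        }],
    }


print(json.dumps(state_machine_assume_policy(TARGET_ROLE_ARN), indent=2))
print(json.dumps(target_role_trust_policy(STATE_MACHINE_ROLE_ARN), indent=2))
```

The target role still needs its own permissions to actually start the task (for example ecs:RunTask and iam:PassRole for the task's roles); those are left out of this sketch.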
Unfortunately, calling SendTask{Heartbeat,Success,Failure} with the task token from the target account does not work: it fails with an invalid task token error.
Using SQS and Lambda as a Go-Between
To work around this, an SQS queue can be used to communicate task status and heartbeats back to the parent account.
This is a very good use case for a FIFO queue: ideally, messages for each running workflow execution arrive and are processed in order.
Note: my examples here are in Terraform.
SQS Queue and Queue Policy
To make this work, the queue needs a policy that allows principals in the target AWS account to send messages to it. This is best accomplished with AWS Organizations, restricting access based on aws:PrincipalOrgID:
resource "aws_sqs_queue" "task-status" {
  name                      = "example-task-status.fifo"
  message_retention_seconds = 600 # set this to the same as the heartbeat timeout on the Task step
  fifo_queue                = true
  # some high throughput settings
  fifo_throughput_limit     = "perMessageGroupId"
  deduplication_scope       = "messageGroup"
  receive_wait_time_seconds = 5
  # may want a redrive policy, etc
}
data "aws_iam_policy_document" "task-status-queue-policy" {
  statement {
    actions = [
      "sqs:SendMessage",
    ]
    resources = [
      aws_sqs_queue.task-status.arn,
    ]
    # allow _any_ principal
    principals {
      type        = "AWS"
      identifiers = ["*"]
    }
    # but restrict to our org
    condition {
      test     = "StringEquals"
      variable = "aws:PrincipalOrgID"
      values   = ["o-ORG_ID_HERE"]
    }
  }
}
resource "aws_sqs_queue_policy" "task-status" {
  queue_url = aws_sqs_queue.task-status.id
  policy    = data.aws_iam_policy_document.task-status-queue-policy.json
}
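Because this is cross-account access, the queue policy alone is not enough: the ECS task role in the target account also needs an identity policy that allows sqs:SendMessage on the queue. A minimal sketch of that policy document, with a hypothetical queue ARN:

```python
import json

# Hypothetical queue ARN in the parent account.
TASK_STATUS_QUEUE_ARN = 'arn:aws:sqs:us-east-1:111111111111:example-task-status.fifo'


def task_role_send_policy(queue_arn):
    """Identity policy for the ECS task role in the target account."""
    return {
        'Version': '2012-10-17',
        'Statement': [{
            'Effect': 'Allow',
            'Action': 'sqs:SendMessage',
            'Resource': queue_arn,
        }],
    }


print(json.dumps(task_role_send_policy(TASK_STATUS_QUEUE_ARN), indent=2))
```

If the queue were encrypted with a customer managed KMS key, the task role would also need kms:GenerateDataKey and kms:Decrypt on that key, and the key policy would have to allow the target account.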
SQS Message Format and Identifiers
The idea is to use the workflow execution's name (or ARN) as the message group ID. This ensures that heartbeats and status updates for a given execution are processed in the order they were sent.
Each message should have a type as well as a task_token that the consumer can use to update the workflow execution.
One subtle gotcha here: SQS FIFO queues can deduplicate either by message content (content-based deduplication) or by an explicit deduplication ID. Heartbeat messages, which always look the same, would be discarded as duplicates under content-based deduplication within the five-minute deduplication interval. Instead, send an explicit message deduplication ID. The message group ID plus a monotonic clock reading makes a good one: the monotonic time only ever increases while the task runs, so every message gets a unique deduplication ID.
import json
import os
import time
import boto3
sqs = boto3.client('sqs')
def send_task_status(type_, body=None):
    # as shown below, these can be passed as environment variables in the ECS RunTask call
    message_group_id = os.environ['SFN_EXECUTION_ID']  # SFN == Step Functions
    queue_url = os.environ['SFN_TASK_STATUS_QUEUE']
    task_token = os.environ['SFN_TASK_TOKEN']
    body = dict(body or {})  # copy so the caller's dict isn't mutated
    body['type'] = type_
    body['task_token'] = task_token
    # monotonic_ns() only ever increases in this process, so each message
    # in the group gets a unique deduplication ID
    mtime = time.monotonic_ns()
    sqs.send_message(
        QueueUrl=queue_url,
        MessageBody=json.dumps(body),
        MessageGroupId=message_group_id,
        MessageDeduplicationId=f'{message_group_id}@{mtime}',
    )
# examples:
send_task_status('heartbeat')
send_task_status('success', {
    'output': {'example': 'output passed to step functions'},
})
send_task_status('failure', {
    'error': 'SomeErrorCodeHere',
    'cause': 'Detailed error description here',
})
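In practice the task has to send heartbeats on an interval while it works, then report success or failure at the end. Here is a minimal sketch using a background thread. The send parameter stands in for send_task_status above so the sketch runs on its own, and the 60-second default interval is an arbitrary assumption; it should sit comfortably below the Task step's heartbeat timeout.

```python
import threading
import traceback


def run_with_heartbeats(work, send, interval=60.0):
    """Run work() while calling send('heartbeat', None) every `interval`
    seconds, then report success or failure through send(type_, body)."""
    stop = threading.Event()

    def beat():
        # Event.wait returns False on timeout and True once stop is set
        while not stop.wait(interval):
            send('heartbeat', None)

    heartbeat = threading.Thread(target=beat, daemon=True)
    heartbeat.start()
    try:
        try:
            output = work()
        finally:
            # stop heartbeats first so the final status is the last message
            # in the message group
            stop.set()
            heartbeat.join()
    except Exception as e:
        send('failure', {
            'error': type(e).__name__,
            'cause': traceback.format_exc(),
        })
        raise
    send('success', {'output': output})
    return output
```

Usage would look like run_with_heartbeats(do_the_work, send_task_status), where do_the_work is a hypothetical function returning a JSON-serializable result. Stopping the heartbeat thread before sending the final status matters: a heartbeat that lands after success or failure would refer to a task that is already closed.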
The task token, the execution name (used as the message group ID), and the queue URL can be passed as environment variables via container overrides when running the task. In Terraform syntax:
Work = {
  Type     = "Task"
  Resource = "arn:aws:states:::ecs:runTask.waitForTaskToken"
  # ...
  Parameters = {
    # skipping a bunch of stuff here to show overrides
    Overrides = {
      ContainerOverrides = [
        {
          Name = "example"
          Environment = [
            {
              Name      = "SFN_TASK_TOKEN"
              "Value.$" = "$$.Task.Token"
            },
            {
              Name      = "SFN_EXECUTION_ID"
              "Value.$" = "$$.Execution.Name"
            },
            {
              Name  = "SFN_TASK_STATUS_QUEUE"
              Value = aws_sqs_queue.task-status.id
            },
          ]
        }
      ]
    }
  }
}
Lambda Queue Consumer
With the queue in place, something needs to pull messages out from it and process them.
Lambda is a bit odd with FIFO queues: to preserve message ordering, AWS doesn't seem to recommend batching messages or reporting partial batch failures. I'm not going to show an infrastructure example here, since my own is heavily abstracted by internal tooling at work, but here is the queue consumer code (in Python):
import json
import logging
import boto3
TYPE_HEARTBEAT = 'heartbeat'
TYPE_SUCCESS = 'success'
TYPE_FAILURE = 'failure'
class InvalidMessage(Exception):
    pass
class InvalidMessageType(InvalidMessage):
    pass
class NoTaskToken(InvalidMessage):
    pass
def handle_record(sfn, record):
    try:
        body = json.loads(record['body'])
    except json.JSONDecodeError as e:
        raise InvalidMessage(f'message body is not valid JSON: {e}') from e
    if not isinstance(body, dict):
        raise InvalidMessage('message body must be a JSON object')
    if 'task_token' not in body:
        raise NoTaskToken('task_token was not sent')
    type_ = body.get('type')
    token = body['task_token']
    if type_ == TYPE_HEARTBEAT:
        return sfn.send_task_heartbeat(taskToken=token)
    if type_ == TYPE_SUCCESS:
        output = body.get('output') or {}
        return sfn.send_task_success(taskToken=token, output=json.dumps(output))
    if type_ == TYPE_FAILURE:
        # boto3 rejects None for string parameters, so only pass error/cause when set
        kwargs = {k: body[k] for k in ('error', 'cause') if body.get(k) is not None}
        return sfn.send_task_failure(taskToken=token, **kwargs)
    raise InvalidMessageType(f'{type_} is not an acceptable record type')
def handle(event, context=None):
    logging.debug(event)
    sfn = boto3.client('stepfunctions')
    # errors that no amount of retrying will fix: bad messages, plus tokens
    # for tasks that have already completed or timed out
    unretryable = (
        InvalidMessage,
        sfn.exceptions.InvalidToken,
        sfn.exceptions.TaskDoesNotExist,
        sfn.exceptions.TaskTimedOut,
    )
    records = event.get('Records', [])
    for record in records:
        try:
            result = handle_record(sfn, record)
            logging.debug(result)
        except unretryable as e:
            # we catch and blackhole these because letting them throw
            # would retry the message, and we don't want that. These
            # messages would never succeed later.
            logging.warning(e)
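If you do want batching, AWS describes a way to keep FIFO ordering with partial batch responses (the ReportBatchItemFailures response type on the event source mapping): stop at the first failure and report that message and every message after it as failed. A minimal sketch of that pattern, with the per-record handler passed in as a parameter (for example functools.partial(handle_record, sfn)) so it stands alone:

```python
import logging


def process_fifo_batch(records, handle_record):
    """Process SQS FIFO records in order. On the first failure, report that
    record and all remaining records so Lambda retries them in order."""
    for index, record in enumerate(records):
        try:
            handle_record(record)
        except Exception:
            logging.exception('failed to process message %s', record['messageId'])
            # everything from the failed record onward goes back to the queue
            return {
                'batchItemFailures': [
                    {'itemIdentifier': r['messageId']} for r in records[index:]
                ],
            }
    # an empty list tells Lambda the whole batch succeeded
    return {'batchItemFailures': []}
```

Invalid messages would still need to be caught and dropped inside the per-record callable, as handle does above; otherwise a permanently bad message would keep blocking its message group until it expires or is redriven.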
The nice thing about the queue consumer is that it's mostly stateless: the task token travels with the queue messages. The consumer isn't tied to any particular backend either; tasks running on Lambda or Step Functions in another account could send messages back to the parent just as easily.
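The consumer's infrastructure isn't shown above, but the piece that connects it to the queue is an SQS event source mapping. A minimal boto3 sketch, assuming a hypothetical queue ARN and function name and a batch size of one to keep ordering simple. The function's execution role would also need sqs:ReceiveMessage, sqs:DeleteMessage, and sqs:GetQueueAttributes on the queue, plus states:SendTaskHeartbeat, states:SendTaskSuccess, and states:SendTaskFailure.

```python
def attach_consumer(queue_arn, function_name, lambda_client=None):
    """Create an SQS event source mapping that delivers one message per invocation."""
    if lambda_client is None:
        # imported lazily so the sketch can be exercised with a stub client
        import boto3
        lambda_client = boto3.client('lambda')
    return lambda_client.create_event_source_mapping(
        EventSourceArn=queue_arn,
        FunctionName=function_name,
        BatchSize=1,  # one message at a time keeps per-group ordering simple
    )
```

For example, attach_consumer('arn:aws:sqs:us-east-1:111111111111:example-task-status.fifo', 'example-task-status-consumer'), where both values are hypothetical.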