# AWS WebSocket + Step Functions notes

# CloudFormation template format version and stack description.
AWSTemplateFormatVersion: '2010-09-09'
Description: A stack that creates the resources required to complete the Amazon API Gateway WebSocket tutorial.

# No API Gateway API is declared below: per the Outputs descriptions, the
# WebSocket API is created in the tutorial and wired to these resources.
Resources:
  # One item per open WebSocket connection, keyed by the string connectionId.
  # The table is deleted with the stack (and on replacement).
  ConnectionsTable:
    Type: AWS::DynamoDB::Table
    DeletionPolicy: Delete
    UpdateReplacePolicy: Delete
    Properties:
      AttributeDefinitions:
        - AttributeName: connectionId
          AttributeType: S
      KeySchema:
        - AttributeName: connectionId
          KeyType: HASH
      ProvisionedThroughput:
        ReadCapacityUnits: 5
        WriteCapacityUnits: 5
  # Execution role for ConnectHandler: assumable by Lambda, with the AWS-managed
  # AWSLambdaBasicExecutionRole policy attached.
  ConnectHandlerServiceRole:
    Type: AWS::IAM::Role
    Properties:
      ManagedPolicyArns:
        - !Sub "arn:${AWS::Partition}:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole"
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Action: sts:AssumeRole
            Principal:
              Service: lambda.amazonaws.com
  # Write actions (plus DescribeTable) on ConnectionsTable for ConnectHandler's role.
  ConnectHandlerServiceRoleDefaultPolicy:
    Type: AWS::IAM::Policy
    Properties:
      PolicyName: ConnectHandlerServiceRoleDefaultPolicy
      Roles:
        - !Ref ConnectHandlerServiceRole
      PolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Action:
              - dynamodb:BatchWriteItem
              - dynamodb:PutItem
              - dynamodb:UpdateItem
              - dynamodb:DeleteItem
              - dynamodb:DescribeTable
            Resource:
              - !GetAtt ConnectionsTable.Arn
              # NOTE(review): looks like a CDK synthesis leftover (empty index-ARN
              # slot); presumably dropped by CloudFormation - confirm before removing.
              - !Ref AWS::NoValue
  # $connect handler: stores the new connection (and the authorizer's principalId)
  # in ConnectionsTable, whose name arrives via the TABLE_NAME env variable.
  ConnectHandler:
    Type: AWS::Lambda::Function
    DependsOn:
      - ConnectHandlerServiceRoleDefaultPolicy
      - ConnectHandlerServiceRole
    Properties:
      Handler: index.lambda_handler
      Runtime: python3.12
      Timeout: 5
      Role: !GetAtt ConnectHandlerServiceRole.Arn
      Environment:
        Variables:
          TABLE_NAME: !Ref ConnectionsTable
      Code:
        ZipFile: |-
          """$connect route handler: records the connection in DynamoDB."""
          import json
          import logging
          import os

          import boto3
          from botocore.exceptions import ClientError

          logger = logging.getLogger()
          logger.setLevel("INFO")

          ddb_client = boto3.client('dynamodb')


          def lambda_handler(event, context):
              """Persist requestContext.connectionId and authorizer.principalId.

              Returns {'statusCode': 200} on success and {'statusCode': 500} if
              anything (missing event keys, missing TABLE_NAME, DynamoDB errors) fails.
              """
              try:
                  request_context = event['requestContext']
                  record = {
                      'connectionId': {'S': request_context['connectionId']},
                      'principalId': {'S': request_context['authorizer']['principalId']},
                  }
                  put_item(table_name=os.environ['TABLE_NAME'], item=record)
              except Exception as e:
                  logger.error("Something went wrong with putting the connection ID into the table! Here's what: %s", e)
                  return {'statusCode': 500}
              return {'statusCode': 200}


          def put_item(table_name, item):
              """Write `item` (DynamoDB attribute-value map) to `table_name`.

              ClientError details are logged and the exception is re-raised.
              """
              try:
                  ddb_client.put_item(TableName=table_name, Item=item)
              except ClientError as err:
                  details = err.response["Error"]
                  logger.error(
                      "Couldn't add item %s to table %s. Here's why: %s: %s",
                      json.dumps(item), table_name, details["Code"], details["Message"],
                  )
                  raise
              else:
                  logger.info("Connection ID added to table: %s", json.dumps(item))
  # Execution role for DisconnectHandler: assumable by Lambda, with the AWS-managed
  # AWSLambdaBasicExecutionRole policy attached.
  DisconnectHandlerServiceRole:
    Type: AWS::IAM::Role
    Properties:
      ManagedPolicyArns:
        - !Sub "arn:${AWS::Partition}:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole"
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Action: sts:AssumeRole
            Principal:
              Service: lambda.amazonaws.com
  # Write actions (plus DescribeTable) on ConnectionsTable for DisconnectHandler's role.
  DisconnectHandlerServiceRoleDefaultPolicy:
    Type: AWS::IAM::Policy
    Properties:
      PolicyName: DisconnectHandlerServiceRoleDefaultPolicy
      Roles:
        - !Ref DisconnectHandlerServiceRole
      PolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Action:
              - dynamodb:BatchWriteItem
              - dynamodb:PutItem
              - dynamodb:UpdateItem
              - dynamodb:DeleteItem
              - dynamodb:DescribeTable
            Resource:
              - !GetAtt ConnectionsTable.Arn
              # NOTE(review): looks like a CDK synthesis leftover (empty index-ARN
              # slot); presumably dropped by CloudFormation - confirm before removing.
              - !Ref AWS::NoValue
  # $disconnect handler: deletes the closed connection's item from ConnectionsTable,
  # whose name arrives via the TABLE_NAME env variable.
  DisconnectHandler:
    Type: AWS::Lambda::Function
    DependsOn:
      - DisconnectHandlerServiceRoleDefaultPolicy
      - DisconnectHandlerServiceRole
    Properties:
      Handler: index.lambda_handler
      Runtime: python3.12
      Role: !GetAtt DisconnectHandlerServiceRole.Arn
      Environment:
        Variables:
          TABLE_NAME: !Ref ConnectionsTable
      Code:
        ZipFile: |-
          """$disconnect route handler: removes the connection from DynamoDB."""
          import json
          import logging
          import os

          import boto3
          from botocore.exceptions import ClientError

          logger = logging.getLogger()
          logger.setLevel("INFO")

          ddb_client = boto3.client('dynamodb')


          def lambda_handler(event, context):
              """Delete the item keyed by requestContext.connectionId.

              Returns {'statusCode': 200} on success and {'statusCode': 500} on any error.
              """
              try:
                  key = {'connectionId': {'S': event['requestContext']['connectionId']}}
                  delete_item(table_name=os.environ['TABLE_NAME'], item=key)
              except Exception as e:
                  logger.error("Something went wrong with deleting the item from the table! Here's what: %s", e)
                  return {'statusCode': 500}
              return {'statusCode': 200}


          def delete_item(table_name, item):
              """Delete the item whose primary key is `item` from `table_name`.

              ClientError details are logged and the exception is re-raised.
              """
              try:
                  ddb_client.delete_item(TableName=table_name, Key=item)
              except ClientError as err:
                  details = err.response["Error"]
                  logger.error(
                      "Couldn't delete item %s from table %s. Here's why: %s: %s",
                      json.dumps(item), table_name, details["Code"], details["Message"],
                  )
                  raise
              else:
                  logger.info("Connection ID removed from table: %s", json.dumps(item))
  # Execution role for SendMessageHandler: assumable by Lambda, with the AWS-managed
  # AWSLambdaBasicExecutionRole policy attached.
  SendMessageHandlerServiceRole:
    Type: AWS::IAM::Role
    Properties:
      ManagedPolicyArns:
        - !Sub "arn:${AWS::Partition}:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole"
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Action: sts:AssumeRole
            Principal:
              Service: lambda.amazonaws.com
  # Read-only DynamoDB actions on ConnectionsTable for SendMessageHandler's role.
  SendMessageHandlerServiceRoleDefaultPolicy:
    Type: AWS::IAM::Policy
    Properties:
      PolicyName: SendMessageHandlerServiceRoleDefaultPolicy
      Roles:
        - !Ref SendMessageHandlerServiceRole
      PolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Action:
              - dynamodb:BatchGetItem
              - dynamodb:GetRecords
              - dynamodb:GetShardIterator
              - dynamodb:Query
              - dynamodb:GetItem
              - dynamodb:Scan
              - dynamodb:ConditionCheckItem
              - dynamodb:DescribeTable
            Resource:
              - !GetAtt ConnectionsTable.Arn
              # NOTE(review): looks like a CDK synthesis leftover (empty index-ARN
              # slot); presumably dropped by CloudFormation - confirm before removing.
              - !Ref AWS::NoValue
  # Broadcast handler invoked by StateMachine: reads every connection from
  # ConnectionsTable (TABLE_NAME env variable) and posts the message to each one.
  SendMessageHandler:
    Type: AWS::Lambda::Function
    Properties:
      Code:
        ZipFile: |-
          import json
          import logging
          import os
          import boto3
          from botocore.exceptions import ClientError

          logger = logging.getLogger()
          logger.setLevel("INFO")

          ddb_client = boto3.client('dynamodb')

          def lambda_handler(event, context):
              """Send event['message'] to every connection stored in the table.

              Presumably the state machine input has the shape
              {"domain": ..., "stage": ..., "message": <str>} - confirm against the
              API Gateway integration created in the tutorial.

              Returns {'statusCode': 500} if the table scan fails, otherwise
              {'statusCode': 200}; failures for individual connections are logged
              and skipped.
              """
              # API Gateway Management API endpoint: https://{domain}/{stage}
              endpoint_url = "https://" + "/".join([
                  event['domain'],
                  event['stage']
              ])
              try:
                  connection_ids = scan_table(os.environ['TABLE_NAME'])
              except Exception as e:
                  logger.error("Scanning the table for connection IDs failed! Here's why: %s", e)
                  return {
                      'statusCode': 500
                  }

              apigateway_client = boto3.client(
                  'apigatewaymanagementapi',
                  endpoint_url=endpoint_url
              )

              for connection_id in connection_ids['Items']:
                  try:
                      send_message(
                          apigateway_client,
                          connection_id['connectionId']['S'],
                          event['message']
                      )
                  except Exception as e:
                      # Best effort: one failing (e.g. stale) connection must not stop the broadcast.
                      logger.error("Sending a message to connection ID: %s failed! Here's why: %s", connection_id, e)

              return {
                  'statusCode': 200
              }

          def scan_table(table_name):
              """Return {'Items': [...]} containing every item in `table_name`.

              A single Scan call returns at most 1 MB of data, so the scan paginator
              is used to follow LastEvaluatedKey; otherwise connections past the
              first page would never receive messages.
              ClientError details are logged and the exception is re-raised.
              """
              try:
                  items = []
                  for page in ddb_client.get_paginator('scan').paginate(TableName=table_name):
                      items.extend(page['Items'])
                  logger.info(
                      "Table scanned successfully: %s",
                      json.dumps(table_name)
                  )
                  return {'Items': items}
              except ClientError as err:
                  logger.error(
                      "Couldn't scan table %s. Here's why: %s: %s",
                      table_name,
                      err.response["Error"]["Code"],
                      err.response["Error"]["Message"],
                  )
                  raise

          def send_message(apigateway_client, connection_id, message):
              """Post `message` (UTF-8 encoded) to one WebSocket connection.

              ClientError details are logged and the exception is re-raised.
              """
              try:
                  response = apigateway_client.post_to_connection(
                      Data=message.encode('utf-8'),
                      ConnectionId=connection_id
                  )
                  logger.info("Message successfully sent: %s", response)
              except ClientError as err:
                  logger.error(
                      "Couldn't send message to client: %s. Here's why: %s: %s",
                      connection_id,
                      err.response["Error"]["Code"],
                      err.response["Error"]["Message"],
                  )
                  raise
      Role:
        Fn::GetAtt:
          - SendMessageHandlerServiceRole
          - Arn
      Environment:
        Variables:
          TABLE_NAME:
            Ref: ConnectionsTable
      Handler: index.lambda_handler
      # The fan-out posts to every connection sequentially; the 3 s Lambda default
      # can cut a broadcast short, and the state machine's retry would then re-send
      # to connections that were already reached.
      Timeout: 30
      Runtime: python3.12
    DependsOn:
      - SendMessageHandlerServiceRoleDefaultPolicy
      - SendMessageHandlerServiceRole
  # Lets SendMessageHandler's role call the API Gateway Management API
  # (execute-api:ManageConnections, POST @connections/*) on any API and stage in
  # this account and region.
  manageConnections:
    Type: AWS::IAM::Policy
    Properties:
      PolicyName: manageConnections7F91357B
      Roles:
        - Ref: SendMessageHandlerServiceRole
      PolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Action: execute-api:ManageConnections
            Effect: Allow
            # Partition comes from AWS::Partition, as the other ARNs in this template
            # do, rather than a hard-coded "aws", so the policy also matches in
            # non-standard partitions (aws-cn, aws-us-gov).
            Resource: !Sub "arn:${AWS::Partition}:execute-api:${AWS::Region}:${AWS::AccountId}:*/*/POST/@connections/*"

  # Execution role for the Authorizer function: assumable by Lambda, with the
  # AWS-managed AWSLambdaBasicExecutionRole policy attached.
  AuthorizerServiceRole:
    Type: AWS::IAM::Role
    Properties:
      ManagedPolicyArns:
        - !Sub "arn:${AWS::Partition}:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole"
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Action: sts:AssumeRole
            Principal:
              Service: lambda.amazonaws.com

  # Lambda authorizer for the WebSocket API (see the AuthorizerFunction output).
  Authorizer:
    Type: AWS::Lambda::Function
    Properties:
      Code:
        ZipFile: |-
          import json
          import logging
          import os

          logger = logging.getLogger()
          logger.setLevel("INFO")


          def lambda_handler(event, context):
              """Return an Allow policy when the Authorization header equals the
              hard-coded tutorial token "Allow", otherwise an explicit Deny policy.

              Header names are matched case-insensitively, and a missing or null
              `headers` map is treated as "no token" (Deny) rather than raising
              KeyError, which would make the authorizer fail instead of denying.
              """
              # NOTE(review): this logs the whole event, including the Authorization
              # header; avoid it if real credentials are ever used.
              logger.info(event)
              token = "Allow"

              # HTTP header names are case-insensitive; normalize before the lookup.
              headers = {
                  name.lower(): value
                  for name, value in (event.get('headers') or {}).items()
              }

              if headers.get('authorization') == token:
                  response = generate_allow('ws-tut-user', event['methodArn'])
                  logger.info('authorized')
                  return json.loads(response)
              else:
                  response = generate_deny('ws-tut-user', event['methodArn'])
                  logger.error('unauthorized')
                  return json.loads(response)


          def generate_policy(principal_id, effect, resource):
              """Build the authorizer response as a JSON string.

              `context.principalId` is what ConnectHandler later reads from
              requestContext.authorizer. The policyDocument is only included when
              both `effect` and `resource` are truthy.
              """
              auth_response = {}
              auth_response['principalId'] = principal_id
              auth_response['context'] = {
                "principalId": principal_id
              }

              if (effect and resource):
                  policy_document = {}
                  policy_document['Version'] = '2012-10-17'
                  policy_document['Statement'] = []
                  statement = {}
                  statement['Action'] = 'execute-api:Invoke'
                  statement['Effect'] = effect
                  statement['Resource'] = resource
                  policy_document['Statement'] = [statement]
                  auth_response['policyDocument'] = policy_document

              auth_response_json = json.dumps(auth_response)
              return auth_response_json


          def generate_allow(principal_id, resource):
              """Allow policy for `resource` (JSON string)."""
              return generate_policy(principal_id, 'Allow', resource)


          def generate_deny(principal_id, resource):
              """Deny policy for `resource` (JSON string)."""
              return generate_policy(principal_id, 'Deny', resource)
      Role:
        Fn::GetAtt:
          - AuthorizerServiceRole
          - Arn
      Handler: index.lambda_handler
      Runtime: python3.12
    DependsOn:
      - AuthorizerServiceRole

  # State machine with a single Task that invokes SendMessageHandler, passing the
  # execution input as the Lambda payload and returning the function's Payload.
  StateMachine:
    Type: AWS::StepFunctions::StateMachine
    Properties:
      StateMachineName: WebSocket-Tutorial-StateMachine
      RoleArn: !GetAtt 'StateMachineRole.Arn'
      # The Lambda service-integration ARN takes its partition from ${AWS::Partition}
      # (like the IAM ARNs elsewhere in this template) instead of hard-coding "aws".
      # Retry re-invokes on any error (States.ALL) up to 3 times, waiting 1 s, 2 s, 4 s.
      DefinitionString: !Sub |
        {
          "Comment": "WebSocket tutorial state machine",
          "StartAt": "Send Message Lambda",
          "States": {
            "Send Message Lambda": {
              "Type": "Task",
              "Resource": "arn:${AWS::Partition}:states:::lambda:invoke",
              "OutputPath": "$.Payload",
              "Parameters": {
                "Payload.$": "$",
                "FunctionName": "${SendMessageHandler.Arn}"
              },
              "Retry": [
                {
                  "ErrorEquals": [
                    "States.ALL"
                  ],
                  "IntervalSeconds": 1,
                  "MaxAttempts": 3,
                  "BackoffRate": 2
                }
              ],
              "End": true
            }
          },
          "TimeoutSeconds": 600
        }

  # Role assumed by Step Functions for StateMachine; its only inline permission is
  # invoking SendMessageHandler.
  StateMachineRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Action: sts:AssumeRole
            Principal:
              Service: !Sub "states.${AWS::Region}.amazonaws.com"
      Policies:
        - PolicyName: ws-tut-send-message-sfn
          PolicyDocument:
            Statement:
              - Effect: Allow
                Action: lambda:InvokeFunction
                Resource:
                  - !GetAtt SendMessageHandler.Arn

  # Role for the tutorial's API Gateway integration: assumable by API Gateway and
  # allowed to start executions of StateMachine (exported as the ApiGatewayRole output).
  ApiGatewayRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Version: '2012-10-17'
        Statement:
          - Effect: Allow
            Principal:
              Service:
                - apigateway.amazonaws.com
            Action:
              - sts:AssumeRole
      # NOTE(review): a fixed RoleName requires CAPABILITY_NAMED_IAM and, since IAM
      # role names are account-wide, allows only one copy of this stack per account.
      RoleName: "WebsocketTutorialApiRole"
      Policies:
        # Named for what it actually grants (states:StartExecution); it confers no
        # logging permissions.
        - PolicyName: ApiGatewayStartExecutionPolicy
          PolicyDocument:
            Version: '2012-10-17'
            Statement:
              - Action: states:StartExecution
                Effect: Allow
                Resource: !GetAtt 'StateMachine.Arn'

# ARNs to plug into the WebSocket API that is created in the tutorial.
Outputs:
  ApiGatewayRole:
    Value: !GetAtt ApiGatewayRole.Arn
    Description: Role to use with API created in the tutorial that integrates with Step Functions
  ConnectHandlerFunction:
    Value: !GetAtt ConnectHandler.Arn
    Description: Lambda function for the $connect route of the WebSocket API
  DisconnectHandlerFunction:
    Value: !GetAtt DisconnectHandler.Arn
    Description: Lambda function for the $disconnect route of the WebSocket API
  AuthorizerFunction:
    Value: !GetAtt Authorizer.Arn
    Description: Lambda function for the Lambda Authorizer of the WebSocket API
  StateMachine:
    Value: !GetAtt StateMachine.Arn
    Description: Step Functions state machine that is executed on the sendMessage route