Sagemaker Utils

AwsSagemakerEngine

AwsSagemakerEngine uses boto3 to set up the IAM, SageMaker, and SageMaker Runtime client connections.

Source code in app/utils/aws_sagemaker_utils.py
class AwsSagemakerEngine:
    """AwsSagemakerEngine uses boto3 to set up the IAM, SageMaker, and SageMaker Runtime client connections.
    """
    def __init__(self, config):
        # IAM client, used only to resolve the SageMaker execution role ARN.
        self.iam = boto3.client(
            'iam',
            aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
            aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
            region_name=config.get("region", "us-east-1")
        )
        self.role = self.iam.get_role(
            RoleName=os.getenv('AWS_SAGEMAKER_ROLE_NAME'))['Role']['Arn']

        # Runtime client for invoking deployed endpoints (streaming and non-streaming).
        self.sagemaker_runtime = boto3.client(
            'sagemaker-runtime',
            aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
            aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
            region_name=config.get("region", "us-east-1")
        )
        # Control-plane client for managing models, endpoint configurations, and endpoints.
        self.sagemaker_client = boto3.client(
            'sagemaker',
            aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
            aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
            region_name=config.get("region", "us-east-1")
        )
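
A minimal usage sketch. It assumes the AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SAGEMAKER_ROLE_NAME environment variables are set; the config dict and region value shown here are illustrative.

from app.utils.aws_sagemaker_utils import AwsSagemakerEngine

# Hypothetical config; only the "region" key is read by the constructor.
config = {"region": "us-east-1"}
engine = AwsSagemakerEngine(config)

print(engine.role)  # ARN of the SageMaker execution role resolved via IAM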

LineIterator

A helper class for parsing the byte stream input.

The output of the model will be in the following format:

{"outputs": [" a"]}   
{"outputs": [" challenging"]}   
{"outputs": [" problem"]}   
...

While each PayloadPart event from the event stream will usually contain a byte array with a complete JSON object, this is not guaranteed, and some JSON objects may be split across PayloadPart events. For example:

{'PayloadPart': {'Bytes': {"outputs": '}}   
{'PayloadPart': {'Bytes': [" problem"]}'}}

This class accounts for this by buffering the incoming bytes and yielding complete lines (ending with a '\n' character) as it is iterated. It keeps track of the last read position so that previously returned bytes are not exposed again.

Source code in app/utils/aws_sagemaker_utils.py
class LineIterator:
    """
    A helper class for parsing the byte stream input. 

    The output of the model will be in the following format:

        {"outputs": [" a"]}   
        {"outputs": [" challenging"]}   
        {"outputs": [" problem"]}   
        ...


    While each PayloadPart event from the event stream will usually contain a byte array
    with a complete JSON object, this is not guaranteed, and some JSON objects may be
    split across PayloadPart events. For example:

        {'PayloadPart': {'Bytes': {"outputs": '}}   
        {'PayloadPart': {'Bytes': [" problem"]}'}}   


    This class accounts for this by buffering the incoming bytes and yielding complete
    lines (ending with a '\\n' character) as it is iterated. It keeps track of the last
    read position so that previously returned bytes are not exposed again.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            # Try to read a complete line from what has been buffered so far.
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord('\n'):
                self.read_pos += len(line)
                return line[:-1]
            # No complete line yet: pull the next chunk from the event stream.
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if 'PayloadPart' not in chunk:
                # Skip non-data events; str() avoids a TypeError when building the message.
                print('Unknown event type: ' + str(chunk))
                continue
            # Append the new bytes to the end of the buffer and retry.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk['PayloadPart']['Bytes'])
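
A small self-contained sketch of the reassembly behaviour. The event stream below is a hand-built list of PayloadPart events (no AWS call involved) in which one JSON object is split across two events.

from app.utils.aws_sagemaker_utils import LineIterator

# Hypothetical event stream: the second JSON object is split across two PayloadPart events.
fake_event_stream = [
    {'PayloadPart': {'Bytes': b'{"outputs": [" a"]}\n{"outputs":'}},
    {'PayloadPart': {'Bytes': b' [" problem"]}\n'}},
]

for line in LineIterator(fake_event_stream):
    print(line)
# b'{"outputs": [" a"]}'
# b'{"outputs": [" problem"]}'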

get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload)

Fetch a streaming text generation response from an AWS SageMaker endpoint.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sagemaker_runtime | sagemaker_runtime | Boto3 SageMaker Runtime client (AwsSagemakerEngine.sagemaker_runtime). | required |
| endpoint_name | str | Endpoint name. | required |
| payload | json | Request payload. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| response_stream | dict | Response from invoke_endpoint_with_response_stream; the event stream is available under the 'Body' key. |

Source code in app/utils/aws_sagemaker_utils.py
def get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload):
    """Fetch Streaming Text Generation response from AWS Sagemaker Enpoint.  

    Args:
        sagemaker_runtime (AwsSagemakerEngine.sagemaker_runtime):   
        endpoint_name (str): Endpoint Name.
        payload (json): Request

    Returns:
        _type_: _description_
    """
    response_stream = sagemaker_runtime.invoke_endpoint_with_response_stream(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
        # Passed through for models that require EULA acceptance (e.g. Llama 2 on JumpStart).
        CustomAttributes='accept_eula=true'
    )
    return response_stream
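
A hedged sketch of invoking an endpoint with streaming enabled. The endpoint name is hypothetical, and the payload follows the TGI convention; the exact schema depends on the container serving the endpoint.

from app.utils.aws_sagemaker_utils import AwsSagemakerEngine, get_realtime_response_stream

engine = AwsSagemakerEngine({"region": "us-east-1"})

# Hypothetical request body for a text-generation endpoint.
payload = {
    "inputs": "Explain what a byte stream is in one sentence.",
    "parameters": {"max_new_tokens": 64},
    "stream": True,
}
# "my-llm-endpoint" is a placeholder for a deployed streaming-capable endpoint.
response_stream = get_realtime_response_stream(
    engine.sagemaker_runtime, "my-llm-endpoint", payload
)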

parse_response_stream(response_stream, verbose=True)

Parses a streaming response into generated text, printing each token as it arrives when verbose is True.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| response_stream | dict | Streaming response returned by get_realtime_response_stream. | required |
| verbose | bool | If True, print each token as it arrives. | True |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| response | str | Generated text. |

Source code in app/utils/aws_sagemaker_utils.py
def parse_response_stream(response_stream, verbose=True):
    """Parses a streaming response into generated text, printing each token as it arrives when verbose is True.

    Args:
        response_stream : Streaming response returned by get_realtime_response_stream.
        verbose (bool, optional): If True, print each token as it arrives. Defaults to True.

    Returns:
        response (str): Generated text.
    """
    response_text = ""
    event_stream = response_stream['Body']
    start_json = b'{'
    stop_token = '</s>'
    for line in LineIterator(event_stream):
        # Each complete line is expected to be a JSON object containing a 'token' field.
        if line != b'' and start_json in line:
            data = json.loads(line[line.find(start_json):].decode('utf-8'))
            # Accumulate everything except the end-of-sequence token.
            if data['token']['text'] != stop_token:
                response_text += data['token']['text']
                if verbose: print(data['token']['text'], end='')
    return response_text
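
Continuing the sketch above, the stream returned by get_realtime_response_stream can be decoded into plain text; with verbose=True the tokens are also echoed as they arrive.

from app.utils.aws_sagemaker_utils import parse_response_stream

text = parse_response_stream(response_stream, verbose=True)
print()                        # newline after the streamed tokens
print("Full response:", text)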