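# NOTE (assumption): the top of this module is omitted in this section. The sketch
# below lists the imports the class appears to rely on; the SageMaker SDK imports
# are standard, while AwsSagemakerEngine, get_hf_image, get_realtime_response_stream,
# parse_response_stream and BaseServerObject are project-local helpers whose module
# paths are not shown here.
# import os
# import json
# from sagemaker.huggingface import HuggingFaceModel
# from sagemaker.utils import name_from_base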
class SagemakerServer(BaseServerObject):
"""
    Serve an expert model using AWS SageMaker Inference.
    Expert models run independently and can be hosted on separate or shared machines;
    they may also include closed-source models such as GPT-4. Refer to app/configs/demo_orch_ec2_mix.json for configuration details.
"""
    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: expert configuration (e.g. model_id, region, instance_type,
                generation parameters) stored as self.config.
        """
self.config = kwargs
os.environ['AWS_DEFAULT_REGION'] = self.config.get("region", "us-east-1")
self.engine = AwsSagemakerEngine(self.config)
self.model_id = self.config['model_id']
self.instance_name = self.model_id
        self._set_inference_params()
self.config.update(dict(logs=[]))
    def _set_inference_params(self):
self.inference_params = {
"top_p": self.config.get("top_p", 0.6),
"top_k": self.config.get("top_k", 50),
"stop": self.config.get("stop", ["</s>"]),
"do_sample": self.config.get("do_sample", True),
"temperature": self.config.get("temperature", 0.9),
"max_new_tokens": self.config.get("max_new_tokens", 512),
"return_full_text": self.config.get("return_full_text", False),
"repetition_penalty": self.config.get("repetition_penalty", 1.03)
}
def get_endpoint_name(self):
"""HACK : endpoint names with more than 2 '-' are not supported 28jan2024.
We use model_id for creating endpoint name and hence trim the later"""
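        # Worked example (the model_id is illustrative): for "tiiuae/falcon-7b-instruct",
        # split("-")[0:2] keeps "tiiuae/falcon-7b", split('/')[1] then keeps "falcon-7b",
        # so the resulting endpoint name is "falcon-7b-tgi-streaming".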
endpoint_name = "-".join(self.config['model_id'].split("-")[0:2])
endpoint_name = f"{endpoint_name.split('/')[1]}-tgi-streaming"
return endpoint_name
def start_server(self):
"""Since aws sagemaker deployment dosent require this step, 'start_inference_endpoint' covers everything.
"""
self.config['logs'].append(f'Sagemaker server {self.model_id}')
def start_inference_endpoint(self, max_wait_time=600):
"""
        Starts an AWS SageMaker endpoint using the 'instance_type' from the configuration.
        'get_endpoint_name' returns a unique identifier for the expert; if an inference endpoint with the same name is already present in the configured region, the operation is skipped.
        NOTE: Hugging Face Hub models are served here through the Hugging Face TGI container on SageMaker.
        Args:
            max_wait_time (int, optional): container startup health-check timeout in seconds. Defaults to 600.
        """
"""Check if inference endpoint is already present"""
is_present = False
instance_meta = {}
endpoints = self.engine.sagemaker_client.list_endpoints()
model_id_trimmed = self.get_endpoint_name()
for endpoint in endpoints['Endpoints']:
if model_id_trimmed in endpoint['EndpointName'] and \
endpoint['EndpointStatus'] in ['Creating', 'InService']:
"""TODO the inference endpoint matching strategy is using the model_id
In case of multiple experts with same model_id assign tags during endpoint creation
and use that as a filter here"""
is_present = True
self.endpoint_name = endpoint['EndpointName']
instance_meta = dict(
is_present=is_present,
ip_address="sagemaker-endpoint",
instance_name=self.endpoint_name,
endpoint_name=self.endpoint_name
)
self.config.update(instance_meta)
if is_present:
return
"""Create a new inference endpoint"""
llm_image = get_hf_image(self.config.get("region", "us-east-1"))
        self.llm_image = str(llm_image)
# sagemaker config
instance_type = self.config['instance_type']
number_of_gpu = self.config.get("number_of_gpu", 1)
# Define Model and Endpoint configuration parameters
config = {
'HF_MODEL_ID': self.model_id, # model_id from hf.co/models
'SM_NUM_GPUS': json.dumps(number_of_gpu), # Number of GPU used per replica
'MAX_INPUT_LENGTH': json.dumps(self.config.get("max_input_length", 2048)), # Max length of input text
            'MAX_TOTAL_TOKENS': json.dumps(self.config.get("max_total_length", 4096)), # Max length of the generation (including input text)
            'MAX_BATCH_TOTAL_TOKENS': json.dumps(self.config.get("max_batch_total_tokens", 8192)), # Limits the number of tokens that can be processed in parallel during generation
'HUGGING_FACE_HUB_TOKEN': os.getenv("HUGGING_FACE_HUB_TOKEN") # Read Access token of your HuggingFace profile https://huggingface.co/settings/tokens
}
# create HuggingFaceModel with the image uri
llm_model = HuggingFaceModel(
role=self.engine.role,
image_uri=self.llm_image,
env=config
)
endpoint_name = self.get_endpoint_name()
self.endpoint_name = name_from_base(endpoint_name)
print(self.endpoint_name)
llm = llm_model.deploy(
endpoint_name=self.endpoint_name,
initial_instance_count=1,
instance_type=instance_type,
wait=False, # Whether the call should wait until the deployment of this model completes
container_startup_health_check_timeout=max_wait_time,
)
instance_meta = dict(
ip_address="sagemaker-endpoint",
instance_name=self.endpoint_name,
endpoint_name=self.endpoint_name
)
self.config.update(instance_meta)
self.config['logs'].append(f'Sagemaker inference endpoint {self.model_id}')
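    # deploy() above is called with wait=False, so the endpoint is usually still
    # 'Creating' when start_inference_endpoint returns. The helper below is a sketch
    # (not part of the original class) showing one way to block until the endpoint is
    # usable, using the standard boto3 SageMaker waiter.
    def _wait_until_in_service(self):
        """Block until the SageMaker endpoint reports 'InService' (sketch)."""
        waiter = self.engine.sagemaker_client.get_waiter('endpoint_in_service')
        waiter.wait(EndpointName=self.endpoint_name)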
def stop_server(self):
"""Terminate the Sagemaker Endpoint
"""
endpoint = self.engine.sagemaker_client.describe_endpoint(
EndpointName=self.endpoint_name)
endpoint_config_name = endpoint['EndpointConfigName']
endpoint_config = self.engine.sagemaker_client.describe_endpoint_config(
EndpointConfigName=endpoint_config_name)
model_name = endpoint_config['ProductionVariants'][0]['ModelName']
print(f"""
About to delete the following sagemaker resources:
Endpoint: {self.endpoint_name}
Endpoint Config: {endpoint_config_name}
Model: {model_name}
""")
# delete endpoint
self.engine.sagemaker_client.delete_endpoint(EndpointName=self.endpoint_name)
# delete endpoint config
self.engine.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
# delete model
self.engine.sagemaker_client.delete_model(ModelName=model_name)
    def _get_inference_payload(self, message, stream=True):
        """Build the TGI invocation payload for the given prompt."""
payload = {
"inputs": message,
"parameters": self.inference_params,
"stream": stream
}
return payload
    def check_servers_state(self):
        """Probe the endpoint with a small streaming request and report whether it is reachable."""
        payload = self._get_inference_payload("hello!", stream=True)
        try:
            resp = get_realtime_response_stream(
                self.engine.sagemaker_runtime,
                self.endpoint_name,
                payload
            )
        except Exception:
            return (False, '')
# print_response_stream(resp)
return (True, 'running')
    def get_response(self, message, stream=True, verbose=True):
        """Invoke the endpoint with 'message' and return the generated text (streaming only for now)."""
payload = self._get_inference_payload(message, stream)
resp = get_realtime_response_stream(
self.engine.sagemaker_runtime,
self.endpoint_name,
payload
)
if stream:
text = parse_response_stream(resp, verbose)
else: # TODO parse the response generated when streaming is false
raise NotImplementedError
return text
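# Usage sketch (assumption -- not part of the original module; the model_id and
# instance_type values below are illustrative):
if __name__ == "__main__":
    server = SagemakerServer(
        model_id="tiiuae/falcon-7b-instruct",
        region="us-east-1",
        instance_type="ml.g5.2xlarge",
        number_of_gpu=1,
    )
    server.start_server()
    server.start_inference_endpoint()
    # ...wait for the endpoint to reach 'InService' before querying it...
    is_up, state = server.check_servers_state()
    if is_up:
        print(server.get_response("Hello!", stream=True))
    server.stop_server()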