Skip to content

Commit

Permalink
[feat] Add support for self-signed SSL certificates (#3225)
Browse files Browse the repository at this point in the history
  • Loading branch information
mihran113 authored Sep 26, 2024
1 parent a566d4a commit e3cb26c
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 9 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Changelog
## Unreleased

### Enhancements
- Add support for self-signed SSL certificates (mihran113)

## 3.24.0 Aug 14, 2024

Expand Down
30 changes: 21 additions & 9 deletions aim/ext/transport/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import logging
import os
import ssl
import threading
import uuid
import weakref
Expand All @@ -10,6 +11,7 @@

import requests

from aim.ext.transport.config import AIM_CLIENT_SSL_CERTIFICATES_FILE
from aim.ext.transport.heartbeat import HeartbeatSender
from aim.ext.transport.message_utils import (
decode_tree,
Expand Down Expand Up @@ -43,6 +45,13 @@ def __init__(self, remote_path: str):
self._http_protocol = 'http://'
self._ws_protocol = 'ws://'
self.request_headers = {}

self.ssl_certfile = os.getenv(AIM_CLIENT_SSL_CERTIFICATES_FILE)
self.ssl_context = None
if self.ssl_certfile:
self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
self.ssl_context.load_cert_chain(certfile=self.ssl_certfile)

self.protocol_probe()

self._resource_pool = weakref.WeakValueDictionary()
Expand Down Expand Up @@ -76,7 +85,7 @@ def protocol_probe(self):

endpoint = f'https://{self.remote_path}/status/'
try:
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
if response.status_code == 200:
self._http_protocol = 'https://'
self._ws_protocol = 'wss://'
Expand Down Expand Up @@ -132,7 +141,7 @@ def _check_remote_version_compatibility(self):

def client_heartbeat(self):
endpoint = f'{self._http_protocol}{self._client_endpoint}/heartbeat/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -145,7 +154,7 @@ def client_heartbeat(self):
)
def connect(self):
endpoint = f'{self._http_protocol}{self._client_endpoint}/connect/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -154,7 +163,7 @@ def connect(self):

def reconnect(self):
endpoint = f'{self._http_protocol}{self._client_endpoint}/reconnect/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -170,7 +179,7 @@ def disconnect(self):
self._ws.close()

endpoint = f'{self._http_protocol}{self._client_endpoint}/disconnect/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -181,7 +190,7 @@ def get_version(
self,
):
endpoint = f'{self._http_protocol}{self._client_endpoint}/get-version/'
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code == 404:
return '<3.19.0'
Expand All @@ -198,7 +207,7 @@ def get_resource_handler(self, resource, resource_type, handler='', args=()):
'args': base64.b64encode(args).decode(),
}

response = requests.post(endpoint, json=request_data, headers=self.request_headers)
response = requests.post(endpoint, json=request_data, headers=self.request_headers, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code == 400:
raise_exception(response_json.get('exception'))
Expand All @@ -215,7 +224,7 @@ def release_resource(self, queue_id, resource_handler):
if queue_id != -1:
self.get_queue().wait_for_finish()

response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code == 400:
raise_exception(response_json.get('exception'))
Expand Down Expand Up @@ -255,7 +264,9 @@ def _run_read_instructions(self, queue_id, resource, method, args):
if queue_id != -1:
self.get_queue().wait_for_finish()

response = requests.post(endpoint, json=request_data, stream=True, headers=self.request_headers)
response = requests.post(
endpoint, json=request_data, stream=True, headers=self.request_headers, verify=self.ssl_certfile
)

if response.status_code == 400:
raise_exception(response.json().get('exception'))
Expand Down Expand Up @@ -297,6 +308,7 @@ def ws(self):
f'{self._ws_protocol}{self._tracking_endpoint}/{self.uri}/write-instruction/',
additional_headers=self.request_headers,
max_size=None,
ssl_context=self.ssl_context,
)

return self._ws
Expand Down
2 changes: 2 additions & 0 deletions aim/ext/transport/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
AIM_SERVER_BASE_PATH = '__AIM_SERVER_BASE_PATH__'

AIM_RT_BEARER_TOKEN = '__AIM_RT_BEARER_TOKEN__'

AIM_CLIENT_SSL_CERTIFICATES_FILE = '__AIM_CLIENT_SSL_CERTIFICATES_FILE__'
20 changes: 20 additions & 0 deletions docs/source/using/remote_tracking.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,26 @@ where `--ssl-keyfile` is the path to the private key file for the certificate an
(check out the [Aim CLI](../refs/cli.html#server) here).


If you're using self-signed certificates the client has to be configured accordingly, otherwise the client will automatically detect if the server supports secure connection or not.
Please set __AIM_CLIENT_SSL_CERTIFICATES_FILE__ environment variable to the file, where PEM encoded root certificates are located, e.g.

```shell
export __AIM_CLIENT_SSL_CERTIFICATES_FILE__=/path/of/the/certs/file
```

__*Note:*__
For the sake of convenience of providing only one file to the client, the private key will be taken from certfile as well.
The following example will do the trick:

```shell
# generate the cert and key files
openssl genrsa -out server.key 2048
openssl req -new -x509 -sha256 -key server.key -out server.crt -days 3650 -subj '/CN={DOMAIN_NAME}'
# append the private key to the certs in a new file
cat server.crt server.key > server.includesprivatekey.pem
# set the env variable for aim client
export __AIM_CLIENT_SSL_CERTIFICATES_FILE__=./server.includesprivatekey.pem
```
### Conclusion

As you can see, aim remote tracking server allows running experiments on multiple hosts with simple setup and
Expand Down

0 comments on commit e3cb26c

Please sign in to comment.