Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Add support for self-signed SSL certificates #3225

Merged
merged 4 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Changelog
## Unreleased

### Enhancements
- Add support for self-signed SSL certificates (mihran113)

## 3.24.0 Aug 14, 2024

Expand Down
30 changes: 21 additions & 9 deletions aim/ext/transport/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import logging
import os
import ssl
import threading
import uuid
import weakref
Expand All @@ -10,6 +11,7 @@

import requests

from aim.ext.transport.config import AIM_CLIENT_SSL_CERTIFICATES_FILE
from aim.ext.transport.heartbeat import HeartbeatSender
from aim.ext.transport.message_utils import (
decode_tree,
Expand Down Expand Up @@ -43,6 +45,13 @@ def __init__(self, remote_path: str):
self._http_protocol = 'http://'
self._ws_protocol = 'ws://'
self.request_headers = {}

self.ssl_certfile = os.getenv(AIM_CLIENT_SSL_CERTIFICATES_FILE)
self.ssl_context = None
if self.ssl_certfile:
self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
self.ssl_context.load_cert_chain(certfile=self.ssl_certfile)

self.protocol_probe()

self._resource_pool = weakref.WeakValueDictionary()
Expand Down Expand Up @@ -76,7 +85,7 @@ def protocol_probe(self):

endpoint = f'https://{self.remote_path}/status/'
try:
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
if response.status_code == 200:
self._http_protocol = 'https://'
self._ws_protocol = 'wss://'
Expand Down Expand Up @@ -132,7 +141,7 @@ def _check_remote_version_compatibility(self):

def client_heartbeat(self):
endpoint = f'{self._http_protocol}{self._client_endpoint}/heartbeat/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -145,7 +154,7 @@ def client_heartbeat(self):
)
def connect(self):
endpoint = f'{self._http_protocol}{self._client_endpoint}/connect/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -154,7 +163,7 @@ def connect(self):

def reconnect(self):
endpoint = f'{self._http_protocol}{self._client_endpoint}/reconnect/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -170,7 +179,7 @@ def disconnect(self):
self._ws.close()

endpoint = f'{self._http_protocol}{self._client_endpoint}/disconnect/{self.uri}/'
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code != 200:
raise_exception(response_json.get('message'))
Expand All @@ -181,7 +190,7 @@ def get_version(
self,
):
endpoint = f'{self._http_protocol}{self._client_endpoint}/get-version/'
response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code == 404:
return '<3.19.0'
Expand All @@ -198,7 +207,7 @@ def get_resource_handler(self, resource, resource_type, handler='', args=()):
'args': base64.b64encode(args).decode(),
}

response = requests.post(endpoint, json=request_data, headers=self.request_headers)
response = requests.post(endpoint, json=request_data, headers=self.request_headers, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code == 400:
raise_exception(response_json.get('exception'))
Expand All @@ -215,7 +224,7 @@ def release_resource(self, queue_id, resource_handler):
if queue_id != -1:
self.get_queue().wait_for_finish()

response = requests.get(endpoint, headers=self.request_headers, timeout=10)
response = requests.get(endpoint, headers=self.request_headers, timeout=10, verify=self.ssl_certfile)
response_json = response.json()
if response.status_code == 400:
raise_exception(response_json.get('exception'))
Expand Down Expand Up @@ -255,7 +264,9 @@ def _run_read_instructions(self, queue_id, resource, method, args):
if queue_id != -1:
self.get_queue().wait_for_finish()

response = requests.post(endpoint, json=request_data, stream=True, headers=self.request_headers)
response = requests.post(
endpoint, json=request_data, stream=True, headers=self.request_headers, verify=self.ssl_certfile
)

if response.status_code == 400:
raise_exception(response.json().get('exception'))
Expand Down Expand Up @@ -297,6 +308,7 @@ def ws(self):
f'{self._ws_protocol}{self._tracking_endpoint}/{self.uri}/write-instruction/',
additional_headers=self.request_headers,
max_size=None,
ssl_context=self.ssl_context,
)

return self._ws
Expand Down
2 changes: 2 additions & 0 deletions aim/ext/transport/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
AIM_SERVER_BASE_PATH = '__AIM_SERVER_BASE_PATH__'

AIM_RT_BEARER_TOKEN = '__AIM_RT_BEARER_TOKEN__'

AIM_CLIENT_SSL_CERTIFICATES_FILE = '__AIM_CLIENT_SSL_CERTIFICATES_FILE__'
20 changes: 20 additions & 0 deletions docs/source/using/remote_tracking.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,26 @@ where `--ssl-keyfile` is the path to the private key file for the certificate an
(check out the [Aim CLI](../refs/cli.html#server) here).


If you're using self-signed certificates, the client has to be configured accordingly; otherwise, the client automatically detects whether the server supports a secure connection.
Please set the `__AIM_CLIENT_SSL_CERTIFICATES_FILE__` environment variable to the path of the file that contains the PEM-encoded root certificates, e.g.

```shell
export __AIM_CLIENT_SSL_CERTIFICATES_FILE__=/path/of/the/certs/file
```

__*Note:*__
For convenience, so that only one file has to be provided to the client, the private key is taken from the same certfile as well.
The following example will do the trick:

```shell
# generate the cert and key files
openssl genrsa -out server.key 2048
openssl req -new -x509 -sha256 -key server.key -out server.crt -days 3650 -subj '/CN={DOMAIN_NAME}'
# append the private key to the certs in a new file
cat server.crt server.key > server.includesprivatekey.pem
# set the env variable for aim client
export __AIM_CLIENT_SSL_CERTIFICATES_FILE__=./server.includesprivatekey.pem
```
### Conclusion

As you can see, aim remote tracking server allows running experiments on multiple hosts with simple setup and
Expand Down
Loading