diff --git a/geniusrise_vision/base/api.py b/geniusrise_vision/base/api.py index 1b164b6..c3ebcb7 100644 --- a/geniusrise_vision/base/api.py +++ b/geniusrise_vision/base/api.py @@ -28,16 +28,6 @@ sequential_lock = threading.Lock() -def sequential_tool(): - with sequential_lock: - # Yield to signal that the request can proceed - yield - - -# Register the custom tool -cherrypy.tools.sequential = cherrypy.Tool("before_handler", sequential_tool) - - class VisionAPI(VisionBulk): """ The VisionAPI class inherits from VisionBulk and is designed to facilitate @@ -93,6 +83,7 @@ def listen( compile: bool = False, flash_attention: bool = False, better_transformers: bool = False, + concurrent_queries: bool = False, endpoint: str = "*", port: int = 3000, cors_domain: str = "http://localhost:3000", @@ -115,6 +106,7 @@ def listen( compile (bool, optional): Whether to compile the model before fine-tuning. Defaults to False. flash_attention (bool): Whether to use flash attention 2. Default is False. better_transformers (bool): Flag to enable Better Transformers optimization for faster processing. + concurrent_queries: (bool): Whether the API supports concurrent API calls (usually false). endpoint (str, optional): The network endpoint for the server. Defaults to "*". port (int, optional): The network port for the server. Defaults to 3000. cors_domain (str, optional): The domain to allow for CORS requests. Defaults to "http://localhost:3000". @@ -134,6 +126,7 @@ def listen( self.compile = compile self.flash_attention = flash_attention self.better_transformers = better_transformers + self.concurrent_queries = concurrent_queries self.model_args = model_args self.username = username self.password = password @@ -175,6 +168,14 @@ def listen( # **self.model_args, ) + def sequential_locker(): + if self.concurrent_queries: + sequential_lock.acquire() + + def sequential_unlocker(): + if self.concurrent_queries: + sequential_lock.release() + def CORS(): """ Configures Cross-Origin Resource Sharing (CORS) for the server. @@ -219,6 +220,8 @@ def CORS(): # Configure basic authentication conf = { "/": { + "tools.sequential_locker.on": True, + "tools.sequential_unlocker.on": True, "tools.auth_basic.on": True, "tools.auth_basic.realm": "geniusrise", "tools.auth_basic.checkpassword": self.validate_password, @@ -227,11 +230,19 @@ def CORS(): } else: # Configuration without authentication - conf = {"/": {"tools.CORS.on": True}} + conf = { + "/": { + "tools.sequential_locker.on": True, + "tools.sequential_unlocker.on": True, + "tools.CORS.on": True, + } + } + cherrypy.tools.sequential_locker = cherrypy.Tool("before_handler", sequential_locker) cherrypy.tools.CORS = cherrypy.Tool("before_handler", CORS) cherrypy.tree.mount(self, "/api/v1/", conf) cherrypy.tools.CORS = cherrypy.Tool("before_finalize", CORS) + cherrypy.tools.sequential_unlocker = cherrypy.Tool("before_finalize", sequential_unlocker) cherrypy.engine.start() cherrypy.engine.block()