From df6cc68599397b9199a027015305d7a54abf33a8 Mon Sep 17 00:00:00 2001 From: OpenVMP Date: Mon, 2 Sep 2024 17:54:22 -0700 Subject: [PATCH] Added Ollama and update other AI script generation features (#179) --- docs/source/configuration.rst | 2 +- docs/source/features.rst | 26 +- examples/produce_part_ai_build123d/README.md | 35 +++ examples/produce_part_ai_build123d/cube.py | 6 + examples/produce_part_ai_build123d/cube.svg | 16 ++ .../produce_part_ai_build123d/partcad.yaml | 36 +++ examples/produce_part_ai_build123d/prism.py | 27 ++ examples/produce_part_ai_build123d/prism.svg | 20 ++ .../produce_part_ai_build123d/tetrahedron.py | 27 ++ .../produce_part_ai_build123d/tetrahedron.svg | 12 + examples/produce_part_ai_cadquery/README.md | 8 +- .../produce_part_ai_cadquery/partcad.yaml | 3 +- examples/produce_part_ai_cadquery/prism.py | 25 +- examples/produce_part_ai_cadquery/prism.svg | 28 +- .../produce_part_ai_cadquery/tetrahedron.py | 41 +-- .../produce_part_ai_cadquery/tetrahedron.svg | 22 +- examples/produce_part_ai_openscad/README.md | 8 +- .../produce_part_ai_openscad/partcad.yaml | 3 +- examples/produce_part_ai_openscad/prism.scad | 24 +- examples/produce_part_ai_openscad/prism.svg | 30 +-- partcad/requirements.txt | 3 + partcad/src/partcad/__init__.py | 2 +- partcad/src/partcad/ai.py | 41 ++- partcad/src/partcad/ai_google.py | 56 ++-- partcad/src/partcad/ai_ollama.py | 160 ++++++++++++ partcad/src/partcad/ai_openai.py | 16 +- .../src/partcad/part_factory_ai_build123d.py | 1 + .../src/partcad/part_factory_ai_cadquery.py | 1 + partcad/src/partcad/part_factory_build123d.py | 9 +- .../src/partcad/part_factory_feature_ai.py | 247 +++++++++++++++--- .../partcad/plugin_export_png_reportlab.py | 4 + partcad/src/partcad/shape.py | 4 +- partcad/src/partcad/user_config.py | 36 +++ .../src/partcad/wrappers/wrapper_build123d.py | 3 +- .../src/partcad/wrappers/wrapper_cadquery.py | 3 +- .../src/partcad/wrappers/wrapper_common.py | 19 ++ 
.../partcad/wrappers/wrapper_render_obj.py | 1 + .../partcad/wrappers/wrapper_render_svg.py | 1 + 38 files changed, 845 insertions(+), 161 deletions(-) create mode 100644 examples/produce_part_ai_build123d/README.md create mode 100644 examples/produce_part_ai_build123d/cube.py create mode 100644 examples/produce_part_ai_build123d/cube.svg create mode 100644 examples/produce_part_ai_build123d/partcad.yaml create mode 100644 examples/produce_part_ai_build123d/prism.py create mode 100644 examples/produce_part_ai_build123d/prism.svg create mode 100644 examples/produce_part_ai_build123d/tetrahedron.py create mode 100644 examples/produce_part_ai_build123d/tetrahedron.svg create mode 100644 partcad/src/partcad/ai_ollama.py diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst index 35e89919..f27f96c1 100644 --- a/docs/source/configuration.rst +++ b/docs/source/configuration.rst @@ -441,7 +441,7 @@ Generate CadQuery or OpenSCAD scripts with Generative AI using the following syn parts: : type: - provider: + provider: tokens: <(optional) the limit of token context> top_p: <(optional, openai only) the top_p parameter> images: <(optional) contextual images as input for AI> diff --git a/docs/source/features.rst b/docs/source/features.rst index a0b758f5..5b2edd3f 100644 --- a/docs/source/features.rst +++ b/docs/source/features.rst @@ -17,7 +17,8 @@ Last but not least, they are powering accessibility features, allowing blind users to navigate the catalog of parts or to interactively create their own designs. -Google and OpenAI models are supported. The following configuration is required: +Google Gemini, OpenAI and Ollama APIs are supported. +The following configuration is required: .. code-block:: yaml @@ -25,6 +26,28 @@ Google and OpenAI models are supported. The following configuration is required: googleApiKey: <...> openaiApiKey: <...> +The following configuration is optional: + + .. 
code-block:: yaml + + # ~/.partcad/config.yaml + # ollamaNumThread is the number of CPU threads Ollama should utilize + ollamaNumThread: + +PartCAD AI agents are designed to query AI multiple times, +so that a range of options is considered and the best result is found. +The following configuration options can be used to influence that bahavior: + + .. code-block:: yaml + + # ~/.partcad/config.yaml + # maxGeometricModeling is the number of attempts for geometric modelling + maxGeometricModeling: 4 + # maxModelGeneration is the number of attempts for CAD script generation + maxModelGeneration: 3 + # maxScriptCorrection is the number of attempts to incrementally fix the script if it's not working + maxScriptCorrection: 2 + Design ------ @@ -42,6 +65,7 @@ The generated part definitions are persisted as Python or CAD scripts. pc inspect "generated-case" To use ChatGPT instead of Gemini, pass "openai" instead of "google" as the "--ai" parameter. +To use Ollama, pass "ollama". If needed, the part can be regenerated by truncating the generated files. diff --git a/examples/produce_part_ai_build123d/README.md b/examples/produce_part_ai_build123d/README.md new file mode 100644 index 00000000..de4a2ca5 --- /dev/null +++ b/examples/produce_part_ai_build123d/README.md @@ -0,0 +1,35 @@ +# /pub/examples/partcad/produce_part_ai_build123d + +PartCAD parts defined using AI-generated build123d scripts. + +## Usage +```shell +pc inspect cube +pc inspect prism +pc inspect tetrahedron +``` + + +## Parts + +### cube + + + +
A cube
+ +### prism + + + +
A hexagonal prism
+ +### tetrahedron + + + +
A tetrahedron
+ +

+ +*Generated by [PartCAD](https://partcad.org/)* diff --git a/examples/produce_part_ai_build123d/cube.py b/examples/produce_part_ai_build123d/cube.py new file mode 100644 index 00000000..3b26dc98 --- /dev/null +++ b/examples/produce_part_ai_build123d/cube.py @@ -0,0 +1,6 @@ +from build123d import * + +with BuildPart() as cube: + Box(length=10, width=10, height=10) + +show_object(cube) \ No newline at end of file diff --git a/examples/produce_part_ai_build123d/cube.svg b/examples/produce_part_ai_build123d/cube.svg new file mode 100644 index 00000000..66bda8a5 --- /dev/null +++ b/examples/produce_part_ai_build123d/cube.svg @@ -0,0 +1,16 @@ + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/produce_part_ai_build123d/partcad.yaml b/examples/produce_part_ai_build123d/partcad.yaml new file mode 100644 index 00000000..cb5e3a2e --- /dev/null +++ b/examples/produce_part_ai_build123d/partcad.yaml @@ -0,0 +1,36 @@ +desc: PartCAD parts defined using AI-generated build123d scripts. 
+ +docs: + usage: | + ```shell + pc inspect cube + pc inspect prism + pc inspect tetrahedron + ``` + +parts: + cube: + type: ai-build123d + provider: google + desc: A cube + properties: + length: 10 + prism: + type: ai-build123d + provider: openai + # provider: ollama # TODO(clairbee): find an Ollama model that works with build123d + desc: A hexagonal prism + properties: + length: 10 + tetrahedron: + type: ai-build123d + provider: openai + tokens: + top_p: 0.9 + desc: A tetrahedron + properties: + length: 10 + +render: + readme: + svg: diff --git a/examples/produce_part_ai_build123d/prism.py b/examples/produce_part_ai_build123d/prism.py new file mode 100644 index 00000000..afd9367c --- /dev/null +++ b/examples/produce_part_ai_build123d/prism.py @@ -0,0 +1,27 @@ +import math +from build123d import * + +# Define the side length of the hexagon +s = 5 # Example side length, can be adjusted + +# Calculate the radius of the circumscribed circle +r = s / math.sqrt(3) + +# Define the vertices of the hexagon +vertices = [ + (r, 0, 0), + (r * math.cos(math.radians(60)), r * math.sin(math.radians(60)), 0), + (r * math.cos(math.radians(120)), r * math.sin(math.radians(120)), 0), + (-r, 0, 0), + (r * math.cos(math.radians(240)), r * math.sin(math.radians(240)), 0), + (r * math.cos(math.radians(300)), r * math.sin(math.radians(300)), 0) +] + +# Create the hexagon base +hexagon = Polygon(vertices) + +# Extrude the hexagon to form the prism +hex_prism = extrude(hexagon, 10) + +# Display the part +show_object(hex_prism) \ No newline at end of file diff --git a/examples/produce_part_ai_build123d/prism.svg b/examples/produce_part_ai_build123d/prism.svg new file mode 100644 index 00000000..d5b5038c --- /dev/null +++ b/examples/produce_part_ai_build123d/prism.svg @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/produce_part_ai_build123d/tetrahedron.py b/examples/produce_part_ai_build123d/tetrahedron.py new file mode 100644 index 
00000000..4f06e2dd --- /dev/null +++ b/examples/produce_part_ai_build123d/tetrahedron.py @@ -0,0 +1,27 @@ +import math +from build123d import * + +# Define vertices of the tetrahedron +A = (0, 0, 0) +B = (10, 0, 0) +C = (5, 8.66, 0) +D = (5, 2.89, 4.71) + +# Create edges +edge_AB = Edge.make_line(A, B) +edge_AC = Edge.make_line(A, C) +edge_AD = Edge.make_line(A, D) +edge_BC = Edge.make_line(B, C) +edge_BD = Edge.make_line(B, D) +edge_CD = Edge.make_line(C, D) + +# Create faces from edges +face_ABC = Face.make_from_wires(Wire.make_wire([edge_AB, edge_BC, edge_AC])) +face_ABD = Face.make_from_wires(Wire.make_wire([edge_AB, edge_BD, edge_AD])) +face_ACD = Face.make_from_wires(Wire.make_wire([edge_AC, edge_CD, edge_AD])) +face_BCD = Face.make_from_wires(Wire.make_wire([edge_BC, edge_CD, edge_BD])) + +# Create the tetrahedron by combining faces +tetrahedron = Solid.make_solid(Shell([face_ABC, face_ABD, face_ACD, face_BCD])) + +show_object(tetrahedron) \ No newline at end of file diff --git a/examples/produce_part_ai_build123d/tetrahedron.svg b/examples/produce_part_ai_build123d/tetrahedron.svg new file mode 100644 index 00000000..8befacdb --- /dev/null +++ b/examples/produce_part_ai_build123d/tetrahedron.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/produce_part_ai_cadquery/README.md b/examples/produce_part_ai_cadquery/README.md index 1094fe99..f29b2334 100644 --- a/examples/produce_part_ai_cadquery/README.md +++ b/examples/produce_part_ai_cadquery/README.md @@ -14,20 +14,22 @@ pc inspect tetrahedron ### cube - +
A cube
### prism - +
A hexagonal prism
### tetrahedron - +
A tetrahedron
+

+ *Generated by [PartCAD](https://partcad.org/)* diff --git a/examples/produce_part_ai_cadquery/partcad.yaml b/examples/produce_part_ai_cadquery/partcad.yaml index fcfffe7c..3e78cd3e 100644 --- a/examples/produce_part_ai_cadquery/partcad.yaml +++ b/examples/produce_part_ai_cadquery/partcad.yaml @@ -17,7 +17,8 @@ parts: length: 10 prism: type: ai-cadquery - provider: openai + provider: ollama + model: llama3.1:70b desc: A hexagonal prism properties: length: 10 diff --git a/examples/produce_part_ai_cadquery/prism.py b/examples/produce_part_ai_cadquery/prism.py index a00e3ae5..5ac958b6 100644 --- a/examples/produce_part_ai_cadquery/prism.py +++ b/examples/produce_part_ai_cadquery/prism.py @@ -1,7 +1,22 @@ -import cadquery as cq +import math +from cadquery import * -# Define the hexagonal prism -length = 10 -hex_prism = cq.Workplane("XY").polygon(6, length).extrude(length) +s = 10 # assuming a side length of 10 mm for demonstration purposes -show_object(hex_prism) \ No newline at end of file +# calculate radius of circumscribed circle using trigonometry +r = s / (2 * math.sin(math.radians(60))) + +# create the hexagonal prism +hexagon_prism = ( + Workplane("XY") + .moveTo(r, 0) + .lineTo(r + s/2, r*math.sqrt(3)/2) + .lineTo(-r - s/2, r*math.sqrt(3)/2) + .lineTo(-r, 0) + .lineTo(-r - s/2, -r*math.sqrt(3)/2) + .lineTo(r + s/2, -r*math.sqrt(3)/2) + .close() + .extrude(10) # extrude to create the prism +) + +show_object(hexagon_prism) \ No newline at end of file diff --git a/examples/produce_part_ai_cadquery/prism.svg b/examples/produce_part_ai_cadquery/prism.svg index 0ac5a92f..1a391615 100644 --- a/examples/produce_part_ai_cadquery/prism.svg +++ b/examples/produce_part_ai_cadquery/prism.svg @@ -1,20 +1,18 @@ - + - - - - - - - - - - - - - - + + + + + + + + + + + + \ No newline at end of file diff --git a/examples/produce_part_ai_cadquery/tetrahedron.py b/examples/produce_part_ai_cadquery/tetrahedron.py index e703eb65..c98cf71a 100644 --- 
a/examples/produce_part_ai_cadquery/tetrahedron.py +++ b/examples/produce_part_ai_cadquery/tetrahedron.py @@ -1,24 +1,31 @@ import math import cadquery as cq -length = 10 +# Define vertices +A = (0, 0, 0) +B = (10, 0, 0) +C = (5, 8.66, 0) +D = (3.33, 2.89, 8.16) -def create_tetrahedron(length): - height = math.sqrt(2/3) * length - radius = math.sqrt(6)/3 * length - - v0 = cq.Vector(0, 0, 0) - v1 = cq.Vector(length, 0, 0) - v2 = cq.Vector(length/2, height, 0) - v3 = cq.Vector(length/2, height/3, radius) - - tetrahedron = cq.Workplane("XY").polyline([v0, v1, v2, v0]) \ - .polyline([v0, v2, v3, v0]) \ - .polyline([v1, v3, v2, v1]) \ - .polyline([v1, v0, v3, v1]) \ - .close() +# Create edges +edges = [ + cq.Edge.makeLine(A, B), + cq.Edge.makeLine(A, C), + cq.Edge.makeLine(A, D), + cq.Edge.makeLine(B, C), + cq.Edge.makeLine(B, D), + cq.Edge.makeLine(C, D) +] - return tetrahedron +# Create faces +faces = [ + cq.Face.makeFromWires(cq.Wire.assembleEdges([edges[0], edges[1], edges[3]])), + cq.Face.makeFromWires(cq.Wire.assembleEdges([edges[0], edges[2], edges[4]])), + cq.Face.makeFromWires(cq.Wire.assembleEdges([edges[1], edges[2], edges[5]])), + cq.Face.makeFromWires(cq.Wire.assembleEdges([edges[3], edges[4], edges[5]])) +] + +# Create solid +tetrahedron = cq.Solid.makeSolid(cq.Shell.makeShell(faces)) -tetrahedron = create_tetrahedron(length) show_object(tetrahedron) \ No newline at end of file diff --git a/examples/produce_part_ai_cadquery/tetrahedron.svg b/examples/produce_part_ai_cadquery/tetrahedron.svg index d8ab636a..08fdfd4c 100644 --- a/examples/produce_part_ai_cadquery/tetrahedron.svg +++ b/examples/produce_part_ai_cadquery/tetrahedron.svg @@ -1,20 +1,12 @@ - + - - - - - - - - - - - - - - + + + + + + \ No newline at end of file diff --git a/examples/produce_part_ai_openscad/README.md b/examples/produce_part_ai_openscad/README.md index 817af456..69edcafd 100644 --- a/examples/produce_part_ai_openscad/README.md +++ b/examples/produce_part_ai_openscad/README.md @@ 
-14,20 +14,22 @@ pc inspect tetrahedron ### cube - +
A cube
### prism - +
A hexagonal prism
### tetrahedron - +
A tetrahedron
+

+ *Generated by [PartCAD](https://partcad.org/)* diff --git a/examples/produce_part_ai_openscad/partcad.yaml b/examples/produce_part_ai_openscad/partcad.yaml index c025a5f1..562b8287 100644 --- a/examples/produce_part_ai_openscad/partcad.yaml +++ b/examples/produce_part_ai_openscad/partcad.yaml @@ -18,7 +18,8 @@ parts: prism: type: ai-openscad desc: A hexagonal prism - provider: openai + provider: ollama + model: llama3.1:70b properties: length: 10 tetrahedron: diff --git a/examples/produce_part_ai_openscad/prism.scad b/examples/produce_part_ai_openscad/prism.scad index bdf498d3..eceb5fff 100644 --- a/examples/produce_part_ai_openscad/prism.scad +++ b/examples/produce_part_ai_openscad/prism.scad @@ -1,6 +1,22 @@ -module hexagonal_prism(length = 10) { - h = sqrt(3)*length; - cylinder(h=length, r = h*0.5, $fn=6); +s = 10; // side length of the hexagonal base +h = 10; // height of the prism + +r_outer = s / sqrt(3); + +module hexagon(s) { + polygon(points=[ + [s/2, 0], + [s/4, s*sqrt(3)/2], + [-s/4, s*sqrt(3)/2], + [-s/2, 0], + [-s/4, -s*sqrt(3)/2], + [s/4, -s*sqrt(3)/2] + ]); +} + +module hexagonal_prism(s, h) { + linear_extrude(height=h) + hexagon(s); } -hexagonal_prism(); \ No newline at end of file +hexagonal_prism(s, h); \ No newline at end of file diff --git a/examples/produce_part_ai_openscad/prism.svg b/examples/produce_part_ai_openscad/prism.svg index d083060e..db515d16 100644 --- a/examples/produce_part_ai_openscad/prism.svg +++ b/examples/produce_part_ai_openscad/prism.svg @@ -1,20 +1,20 @@ - + - - - - - - - - - - - - - - + + + + + + + + + + + + + + \ No newline at end of file diff --git a/partcad/requirements.txt b/partcad/requirements.txt index 3da6457b..7607ee0c 100644 --- a/partcad/requirements.txt +++ b/partcad/requirements.txt @@ -18,3 +18,6 @@ Pillow # OpenAI openai + +# Locally hosted AI +ollama diff --git a/partcad/src/partcad/__init__.py b/partcad/src/partcad/__init__.py index d4d1804f..c8980ff4 100644 --- a/partcad/src/partcad/__init__.py +++ 
b/partcad/src/partcad/__init__.py @@ -13,7 +13,7 @@ _partcad_context, render, ) -from .ai import models +from .ai import supported_models from .consts import * from .context import Context from .assembly import Assembly diff --git a/partcad/src/partcad/ai.py b/partcad/src/partcad/ai.py index 47db0781..e405e9da 100644 --- a/partcad/src/partcad/ai.py +++ b/partcad/src/partcad/ai.py @@ -7,16 +7,18 @@ # Licensed under Apache License, Version 2.0. # +import fnmatch import time from .ai_google import AiGoogle from .ai_openai import AiOpenAI +from .ai_ollama import AiOllama from . import logging as pc_logging from .user_config import user_config -models = [ +supported_models = [ "gpt-3.5-turbo", "gpt-4", "gpt-4-vision-preview", @@ -26,10 +28,16 @@ "gemini-pro-vision", "gemini-1.5-pro", "gemini-1.5-flash", + "llama3.1*", + "codellama*", + "codegemma*", + "gemma*", + "deepseek-coder*", + "codestral*", ] -class Ai(AiGoogle, AiOpenAI): +class Ai(AiGoogle, AiOpenAI, AiOllama): def generate( self, action: str, @@ -42,10 +50,10 @@ def generate( ): with pc_logging.Action("Ai" + action, package, item): # Determine the model to use + provider = config.get("provider", None) if "model" in config and config["model"] is not None: model = config["model"] else: - provider = config.get("provider", None) if provider is None: if not user_config.openai_api_key is None: provider = "openai" @@ -64,19 +72,26 @@ def generate( # else: # model = "gpt-4" model = "gpt-4o" + elif provider == "ollama": + model = "llama3.1:70b" else: error = "Provider %s is not supported" % provider pc_logging.error(error) return [] # Generate the content - if not model in models: + is_supported = False + for supported_model_pattern in supported_models: + if fnmatch.fnmatch(model, supported_model_pattern): + is_supported = True + break + if not is_supported: error = "Model %s is not supported" % model pc_logging.error(error) return [] result = [] - if model.startswith("gemini"): + if provider == "google": try: result = 
self.generate_google( model, @@ -91,7 +106,7 @@ def generate( ) time.sleep(1) # Safeguard against exceeding quota - elif model.startswith("gpt"): + elif provider == "openai": try: result = self.generate_openai( model, @@ -106,6 +121,20 @@ def generate( ) time.sleep(1) # Safeguard against exceeding quota + elif provider == "ollama": + try: + result = self.generate_ollama( + model, + prompt, + image_filenames, + config, + num_options, + ) + except Exception as e: + pc_logging.error( + "Failed to generate with Ollama: %s" % str(e) + ) + else: pc_logging.error( "Failed to associate the model %s with the provider" % model diff --git a/partcad/src/partcad/ai_google.py b/partcad/src/partcad/ai_google.py index 12fd2ac1..d9ecf088 100644 --- a/partcad/src/partcad/ai_google.py +++ b/partcad/src/partcad/ai_google.py @@ -40,27 +40,15 @@ def google_once(): with lock: if pil_image is None: - try: - pil_image = importlib.import_module("PIL.Image") - except Exception as e: - pc_logging.exception(e) - return + pil_image = importlib.import_module("PIL.Image") if google_genai is None: - try: - google_genai = importlib.import_module("google.generativeai") - except Exception as e: - pc_logging.exception(e) - return + google_genai = importlib.import_module("google.generativeai") if google_api_core_exceptions is None: - try: - google_api_core_exceptions = importlib.import_module( - "google.api_core.exceptions" - ) - except Exception as e: - pc_logging.exception(e) - return + google_api_core_exceptions = importlib.import_module( + "google.api_core.exceptions" + ) latest_key = user_config.google_api_key if latest_key != GOOGLE_API_KEY: @@ -70,9 +58,7 @@ def google_once(): return True if GOOGLE_API_KEY is None: - error = "Google API key is not set" - pc_logging.error(error) - return False + raise Exception("Google API key is not set") return True @@ -94,9 +80,24 @@ def generate_google( else: tokens = model_tokens[model] if model in model_tokens else None + if "top_p" in config: + top_p = 
config["top_p"] + else: + top_p = None + + if "top_k" in config: + top_k = config["top_k"] + else: + top_k = None + + if "temperature" in config: + temperature = config["temperature"] + else: + temperature = None + images = list( map( - lambda f: PIL.Image.open(f), + lambda f: pil_image.open(f), image_filenames, ) ) @@ -105,13 +106,17 @@ def generate_google( client = google_genai.GenerativeModel( model, generation_config={ - # "candidate_count": options_num, + "candidate_count": 1, + # "candidate_count": options_num, # TODO(clairbee): not supported yet? not any more? "max_output_tokens": tokens, - # "temperature": 0.97, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, }, ) candidates = [] - for _ in range(options_num): + options_left = options_num + while options_left > 0: retry = True while retry == True: retry = False @@ -132,7 +137,10 @@ def generate_google( error = "%s: Failed to generate content" % self.name pc_logging.error(error) continue + + options_created = len(response.candidates) candidates.extend(response.candidates) + options_left -= options_created if options_created > 0 else 1 products = [] try: diff --git a/partcad/src/partcad/ai_ollama.py b/partcad/src/partcad/ai_ollama.py new file mode 100644 index 00000000..74d8cf43 --- /dev/null +++ b/partcad/src/partcad/ai_ollama.py @@ -0,0 +1,160 @@ +# +# OpenVMP, 2024 +# +# Author: Roman Kuzmenko +# Created: 2024-08-24 +# +# Licensed under Apache License, Version 2.0. +# + +import importlib +import httpx +import threading +import time +from typing import Any + +# Lazy-load AI imports as they are not always needed +# import ollama +ollama = None + +from . 
import logging as pc_logging +from .user_config import user_config + +lock = threading.Lock() + +model_tokens = {} + +ollama_num_thread = None + +models_pulled = {} + + +def ollama_once(): + global ollama, ollama_num_thread + + with lock: + if ollama is None: + ollama = importlib.import_module("ollama") + + ollama_num_thread = user_config.ollama_num_thread + + return True + + +def model_once(model: str): + global models_pulled + + ollama_once() + + with lock: + if not model in models_pulled: + pc_logging.info("Pulling %s..." % model) + ollama.pull(model) + models_pulled[model] = True + + +class AiOllama: + def generate_ollama( + self, + model: str, + prompt: str, + image_filenames: list[str] = [], + config: dict[str, Any] = {}, + options_num: int = 1, + ): + model_once(model) + + pc_logging.info( + "Generating with Ollama: asking for %d alternatives", options_num + ) + + if not ollama_once(): + return None + + if len(image_filenames) > 0: + raise NotImplementedError("Images are not supported by Ollama") + + if "tokens" in config: + tokens = config["tokens"] + else: + tokens = model_tokens[model] if model in model_tokens else None + + if "top_p" in config: + top_p = config["top_p"] + else: + top_p = None + + if "top_k" in config: + top_k = config["top_k"] + else: + top_k = None + + if "temperature" in config: + temperature = config["temperature"] + else: + temperature = None + + candidates = [] + for _ in range(options_num): + retry = True + while retry == True: + retry = False + try: + options = ollama.Options( + tokens=tokens, + num_thread=ollama_num_thread, + top_p=top_p, + top_k=top_k, + temperature=temperature, + ) + response = ollama.generate( + model=model, + context=[], # do not accumulate context uncontrollably + prompt=prompt, + options=options, + ) + except httpx.ConnectError as e: + pc_logging.exception(e) + pc_logging.error( + "Failed to connect to Ollama. Is it running?" 
+ ) + retry = True + time.sleep(15) + except ollama._types.ResponseError as e: + pc_logging.exception(e) + pc_logging.error( + "Failed to generate with Ollama: %s" % str(e) + ) + pc_logging.warning( + f"Consider running 'ollama run {model}' first..." + ) + if "not found" in str(e): + retry = True + time.sleep(15) + else: + time.sleep(1) + except Exception as e: + pc_logging.exception(e) + retry = True + time.sleep(1) + + if not response or not "response" in response: + error = "%s: Failed to generate content" % self.name + pc_logging.error(error) + return + + pc_logging.info("Response: %s", str(response)) + # Perform Ollama-specific output sanitization + response_text = response["response"] + response_text = response_text.replace("\\'", "'") + candidates.append(response_text) + + products = [] + try: + for candidate in candidates: + if candidate: + products.append(candidate) + except Exception as e: + pc_logging.exception(e) + + return products diff --git a/partcad/src/partcad/ai_openai.py b/partcad/src/partcad/ai_openai.py index 4e0900f6..83b87423 100644 --- a/partcad/src/partcad/ai_openai.py +++ b/partcad/src/partcad/ai_openai.py @@ -40,11 +40,7 @@ def openai_once(): with lock: if openai_genai is None: - try: - openai_genai = importlib.import_module("openai") - except Exception as e: - pc_logging.exception(e) - return False + openai_genai = importlib.import_module("openai") latest_key = user_config.openai_api_key if latest_key != OPENAI_API_KEY: @@ -55,9 +51,7 @@ def openai_once(): return True if OPENAI_API_KEY is None: - error = "OpenAI API key is not set" - pc_logging.error(error) - return False + raise Exception("OpenAI API key is not set") return True @@ -84,6 +78,11 @@ def generate_openai( else: top_p = None + if "temperature" in config: + temperature = config["temperature"] + else: + temperature = None + content = [ {"type": "text", "text": prompt}, *list( @@ -113,6 +112,7 @@ def generate_openai( n=options_num, max_tokens=tokens, top_p=top_p, + 
temperature=temperature, model=model, ) diff --git a/partcad/src/partcad/part_factory_ai_build123d.py b/partcad/src/partcad/part_factory_ai_build123d.py index 41d6e038..f57bd35e 100644 --- a/partcad/src/partcad/part_factory_ai_build123d.py +++ b/partcad/src/partcad/part_factory_ai_build123d.py @@ -16,6 +16,7 @@ class PartFactoryAiBuild123d(PartFactoryBuild123d, PartFactoryFeatureAi): def __init__(self, ctx, source_project, target_project, config): # Override the path determined by the parent class to enable "enrich" config["path"] = config["name"] + ".py" + self.lang = self.LANG_PYTHON mode = "builder" if "mode" in config and config["mode"] == "algebra": diff --git a/partcad/src/partcad/part_factory_ai_cadquery.py b/partcad/src/partcad/part_factory_ai_cadquery.py index 56e7bb30..8634eecf 100644 --- a/partcad/src/partcad/part_factory_ai_cadquery.py +++ b/partcad/src/partcad/part_factory_ai_cadquery.py @@ -16,6 +16,7 @@ class PartFactoryAiCadquery(PartFactoryCadquery, PartFactoryFeatureAi): def __init__(self, ctx, source_project, target_project, config): # Override the path determined by the parent class to enable "enrich" config["path"] = config["name"] + ".py" + self.lang = self.LANG_PYTHON with pc_logging.Action("InitAiCq", target_project.name, config["name"]): PartFactoryFeatureAi.__init__( diff --git a/partcad/src/partcad/part_factory_build123d.py b/partcad/src/partcad/part_factory_build123d.py index 3e6273b3..44a3308c 100644 --- a/partcad/src/partcad/part_factory_build123d.py +++ b/partcad/src/partcad/part_factory_build123d.py @@ -114,7 +114,12 @@ async def instantiate(self, part): if len(errors) > 0: error_lines = errors.split("\n") for error_line in error_lines: - part.error("%s: %s" % (part.name, error_line)) + error_line = error_line.strip() + if not error_line: + continue + # TODO(clairbee): Move the part name concatenation to where the logging happens + # part.error("%s: %s" % (part.name, error_line)) + part.error(error_line) try: # 
pc_logging.error("Response: %s" % response_serialized)
@@ -128,7 +133,7 @@ async def instantiate(self, part):
             return None
 
         if not result["success"]:
-            part.error("%s: %s" % (part.name, result["exception"]))
+            pc_logging.error("Failed to produce the part: %s" % part.name)
             return None
 
         self.ctx.stats_parts_instantiated += 1
diff --git a/partcad/src/partcad/part_factory_feature_ai.py b/partcad/src/partcad/part_factory_feature_ai.py
index 29104779..90097b1f 100644
--- a/partcad/src/partcad/part_factory_feature_ai.py
+++ b/partcad/src/partcad/part_factory_feature_ai.py
@@ -18,14 +18,18 @@
 from . import logging as pc_logging
 from . import sync_threads as pc_thread
 from .utils import total_size
+from .user_config import user_config
 
-NUM_ALTERNATIVES_GEOMETRIC_MODELING = 3
-NUM_ALTERNATIVES_MODEL_GENERATION = 3
+DEFAULT_ALTERNATIVES_GEOMETRIC_MODELING = 3
+DEFAULT_ALTERNATIVES_MODEL_GENERATION = 3
+DEFAULT_INCREMENTAL_SCRIPT_CORRECTION = 2
 
 
 class PartFactoryFeatureAi(Ai):
     """Used by all part factories that generate parts with GenAI."""
 
+    LANG_PYTHON = "Python"
+
     part_type: str
     script_type: str
     prompt_suffix: str
@@ -42,17 +46,48 @@ def __init__(self, config, part_type, script_type, prompt_suffix=""):
             raise Exception(error)
 
         self.ai_config = copy.deepcopy(config)
-        if not "numGeometricModeling" in self.ai_config:
-            self.ai_config["numGeometricModeling"] = (
-                NUM_ALTERNATIVES_GEOMETRIC_MODELING
-            )
-        if not "numModelGeneration" in self.ai_config:
-            self.ai_config["numModelGeneration"] = (
-                NUM_ALTERNATIVES_MODEL_GENERATION
+        if "numGeometricModeling" in self.ai_config:
+            self.num_geometric_modeling = self.ai_config["numGeometricModeling"]
+        else:
+            self.num_geometric_modeling = (
+                DEFAULT_ALTERNATIVES_GEOMETRIC_MODELING
             )
+        if (
+            user_config.max_geometric_modeling is not None
+            and self.num_geometric_modeling > user_config.max_geometric_modeling
+        ):
+            self.num_geometric_modeling = user_config.max_geometric_modeling
+
+        if "numModelGeneration" in self.ai_config:
+            
self.num_model_generation = self.ai_config[("numModelGeneration")] + else: + self.num_model_generation = DEFAULT_ALTERNATIVES_MODEL_GENERATION + if ( + user_config.max_model_generation is not None + and self.num_model_generation > user_config.max_model_generation + ): + self.num_model_generation = user_config.max_model_generation + + if "numScriptCorrection" in self.ai_config: + self.num_script_correction = self.ai_config[("numScriptCorrection")] + else: + self.num_script_correction = DEFAULT_INCREMENTAL_SCRIPT_CORRECTION + if ( + user_config.max_script_correction is not None + and self.num_script_correction > user_config.max_script_correction + ): + self.num_script_correction = user_config.max_script_correction + if not "tokens" in self.ai_config: self.ai_config["tokens"] = 2000 + # Use `temperature` and `top_p` values recommended for code generation + # if no other preferences are set + if not "temperature" in self.ai_config: + self.ai_config["temperature"] = 0.2 + if not "top_p" in self.ai_config: + self.ai_config["top_p"] = 0.1 + def on_init_ai(self): """This method must be executed at the very end of the part factory constructor to finalize the AI initialization. At the time of the call @@ -82,7 +117,13 @@ def _create_file(self, path): # Geometric modeling modeling_options = [] + max_models = self.num_geometric_modeling + max_tries = 2 * max_models + # TODO(clairbee): Multiple AI calls are a good candidate to paralelize, + # however it's useless without a paid subscription + # with huge quotas. De-prioritized for now. 
+ # # def modeling_task(): # modeling_options.extend(self._geometric_modeling()) # threads = [] @@ -93,12 +134,14 @@ def _create_file(self, path): # for thread in threads: # thread.join() - while len(modeling_options) < self.ai_config["numGeometricModeling"]: + tries = 0 + while len(modeling_options) < max_models and tries < max_tries: modeling_options.extend(self._geometric_modeling()) pc_logging.info( "Generated %d geometric modeling candidates" % len(modeling_options) ) + tries += 1 # For each remaining geometric modeling option, # generate a model and render an image @@ -119,9 +162,11 @@ def _create_file(self, path): % (candidate_id, script) ) - # Render an image of the model - image_filename = self._render_image(script, candidate_id) - + # Validate the image by rendering it, + # attempt to correct the script if rendering doesn't work + image_filename, script = self._validate_and_fix( + modeling_option, script, candidate_id + ) # Check if the model was valid if image_filename is not None: # Record the valid model and the image @@ -197,9 +242,7 @@ def _geometric_modeling(self): self.name, prompt, self.ai_config, - self.ai_config[ - "numGeometricModeling" - ], # NUM_ALTERNATIVES_GEOMETRIC_MODELING, + self.num_geometric_modeling, image_filenames=image_filenames, ) return options @@ -207,7 +250,12 @@ def _geometric_modeling(self): def _generate_script(self, geometric_modeling): """This method generates a script given specific geometric modeling.""" - prompt = """Generate a %s to define a 3D model of a part defined by + prompt = """You are an AI assistant in an engineering department. +You are helping engineers to create programmatic scripts that produce CAD geometry data +for parts, mechanisms, buildings or anything else. +The scripts you create a fully functional and can be used right away, as is, in automated workflows. +Assume that the scripts you produce are used automatically to render 3D models and to validate them. 
+This time you are asked to generate a %s to define a 3D model of a part defined by the following geometric modeling: %s """ % ( @@ -236,7 +284,6 @@ def _generate_script(self, geometric_modeling): self.name, prompt, self.ai_config, - self.ai_config["numModelGeneration"], image_filenames=image_filenames, ) @@ -245,10 +292,17 @@ def _generate_script(self, geometric_modeling): return scripts - def _sanitize_script(self, script): + def _sanitize_script(self, script: str): """Cleans up the GenAI output to keep the code only.""" - # Remove code blocks + # Extract the first block between ``` if any + loc_1 = script.find("```") + loc_2 = script[loc_1 + 1 :].find("```") + loc_1 + if loc_2 > 0: + # Note: the first ``` (and whatever follows on the same line) is still included + script = script[loc_1:loc_2] + + # Remove ```` if anything is left script = "\n".join( list( filter( @@ -262,10 +316,122 @@ def _sanitize_script(self, script): ) ) + if hasattr(self, "lang") and self.lang == self.LANG_PYTHON: + # Strip straight to the import statements (for AIs that don't ```) + if script.startswith("from "): + loc_from = 0 + else: + loc_from = script.find("\nfrom ") + if script.startswith("import "): + loc_import = 0 + else: + loc_import = script.find("\nimport ") + + if loc_from != -1 and loc_import != -1: + loc = min(loc_import, loc_from) + elif loc_from != -1: + loc = loc_from + else: + loc = loc_import + + if loc != -1: + script = script[loc:] + return script + def _validate_and_fix(self, modeling_option, script, candidate_id, depth=0): + """ + Validate the image by rendering it, + attempt to correct the script if rendering doesn't work. + """ + next_depth = depth + 1 + + # Render this script into an image. + # We can't just stop at validating geometry, + # as we need to feed the picture into the following AI logic. 
+ image_filename, error_text = self._render_image(script, candidate_id) + if image_filename is not None: + return image_filename, script + # Failed to render the image. + + if next_depth <= self.num_script_correction: + # Ask AI to make incremental fixes based on the errors. + correction_candidate_id = 0 + for _ in range(self.num_script_correction): + corrected_scripts = self._correct_script( + modeling_option, script, error_text + ) + corrected_scripts = list( + map(lambda s: self._sanitize_script(s), corrected_scripts) + ) + for corrected_script in corrected_scripts: + pc_logging.debug( + "Corrected the script candidate %d, correction candidate %d at depth %d: %s" + % ( + candidate_id, + correction_candidate_id, + depth, + corrected_script, + ) + ) + + image_filename, corrected_script = self._validate_and_fix( + modeling_option, + corrected_script, + candidate_id, + next_depth, + ) + if image_filename is not None: + return image_filename, corrected_script + + correction_candidate_id += 1 + + return None, script + + def _correct_script(self, modeling_option, script, error_text): + # TODO(clairbee): prove that the use of geometric modeling product + # in this prompt is benefitial + prompt = """You are an AI assistant to a mechanical engineer. +You are given an automatically generated %s which has flaws that need to be +corrected. + +The generated script that contains errors is (until SCRIPT_END): +``` +%s +``` +SCRIPT_END + +When the given script is executed, the following error messages are +produced (until ERRORS_END): +%s +ERRORS_END + +Please, generate a corrected script so that it does not produce the given errors. +Make as little changes as possible and prefer not to make any changes that are +not necessary to fix the errors. +Very important not to produce exactly the same script: at least something has to change. 
+""" % ( + self.script_type, + script, + error_text, + ) + + pc_logging.debug("Correction prompt: %s" % prompt) + + options = self.generate( + "ScriptIncr", + self.project.name, + self.name, + prompt, + self.ai_config, + self.num_script_correction, + ) + return options + def _render_image(self, script, candidate_id): """This method ensures the validity of a script candidate by attempting to render it.""" + error_text = "" + exception_text = "" source_path = tempfile.mktemp(suffix=self.extension) with open(source_path, "w") as f: @@ -291,6 +457,7 @@ def _render_image(self, script, candidate_id): pc_logging.debug("Part created: %.2f KB" % (total_size(part) / 1024.0)) def render(part): + nonlocal exception_text try: coro = part.get_shape() with pc_logging.Action( @@ -301,33 +468,42 @@ def render(part): part.render_png(self.ctx, None, output_path) except Exception as e: part.error("Failed to render the image: %s" % e) - - try: - # Given we don't know whether the current thread is already - # running an event loop or not, we need to create a new thread - t = threading.Thread(target=render, args=[part]) - t.start() - t.join() - except Exception as e: - part.error("Failed to render the image: %s" % e) - - if len(part.errors) > 0: + exception_text += f"Exception:\n{str(e)}\n" + + # TODO(clairbee): make it async up until this point, drop threads + # Given we don't know whether the current thread is already + # running an event loop or not, we need to create a new thread + t = threading.Thread(target=render, args=[part]) + t.start() + t.join() + + errors = part.errors # Save the errors + errors = list( + filter( + lambda e: "Exception while deserializing" not in e, + errors, + ) + ) + if len(errors) > 0: pc_logging.debug( "There were errors while attemtping to render the image" ) - pc_logging.debug("%s" % part.errors) + pc_logging.debug("%s" % errors) + for error in errors: + error_text += f"{error}\n" + error_text = error_text + exception_text os.unlink(source_path) if not 
os.path.exists(output_path) or os.path.getsize(output_path) == 0: pc_logging.info( "Script candidate %d: failed to render the image" % candidate_id ) - return None + return None, error_text pc_logging.info( "Script candidate %d: successfully rendered the image" % candidate_id ) - return output_path + return output_path, error_text def select_best_image(self, script_candidates): """Iterate over script_candidates and compare images. @@ -352,6 +528,9 @@ def select_best_image(self, script_candidates): ) # Ask AI to compare the images + pc_logging.info( + "Attempting to select the best script by comparing images" + ) responses = self.generate( "Compare", self.project.name, diff --git a/partcad/src/partcad/plugin_export_png_reportlab.py b/partcad/src/partcad/plugin_export_png_reportlab.py index 87a9bd82..26c74f03 100644 --- a/partcad/src/partcad/plugin_export_png_reportlab.py +++ b/partcad/src/partcad/plugin_export_png_reportlab.py @@ -32,6 +32,10 @@ def export(self, project, svg_path, width, height, filepath): # Render the raster image drawing = svglib.svg2rlg(svg_path) + if drawing is None: + pc_logging.error('Failed to convert to RLG. 
Aborting.') + return + scale_width = float(width) / float(drawing.width) scale_height = float(height) / float(drawing.height) scale = min(scale_width, scale_height) diff --git a/partcad/src/partcad/shape.py b/partcad/src/partcad/shape.py index 208f7b4d..2b52c144 100644 --- a/partcad/src/partcad/shape.py +++ b/partcad/src/partcad/shape.py @@ -37,10 +37,11 @@ class Shape(ShapeConfiguration): svg_url: str # shape: None | b3d.TopoDS_Shape | OCP.TopoDS.TopoDS_Solid - errors: list[str] = [] + errors: list[str] def __init__(self, config): super().__init__(config) + self.errors = [] self.lock = asyncio.Lock() self.shape = None self.components = [] @@ -232,6 +233,7 @@ async def render_svg_somewhere( request_serialized = base64.b64encode(picklestring).decode() runtime = ctx.get_python_runtime(python_runtime="none") + await runtime.ensure("cadquery") # SVG wrapper requires cq-serialize await runtime.ensure("build123d") response_serialized, errors = await runtime.run( [ diff --git a/partcad/src/partcad/user_config.py b/partcad/src/partcad/user_config.py index 82ffba0a..b703530c 100644 --- a/partcad/src/partcad/user_config.py +++ b/partcad/src/partcad/user_config.py @@ -86,5 +86,41 @@ def __init__(self): else: self.openai_api_key = None + # option: ollamaNumThread + # description: Ask Ollama to use the given number of CPU threads + # values: + # default: None + if "ollamaNumThread" in self.config_obj: + self.ollama_num_thread = int(self.config_obj["ollamaNumThread"]) + else: + self.ollama_num_thread = None + + # option: maxGeometricModeling + # description: the number of attempts for geometric modelling + # values: + # default: None + if "maxGeometricModeling" in self.config_obj: + self.max_geometric_modeling = int(self.config_obj["maxGeometricModeling"]) + else: + self.max_geometric_modeling = None + + # option: maxModelGeneration + # description: the number of attempts for CAD script generation + # values: + # default: None + if "maxModelGeneration" in self.config_obj: + 
self.max_model_generation = int(self.config_obj["maxModelGeneration"]) + else: + self.max_model_generation = None + + # option: maxScriptCorrection + # description: the number of attempts to incrementally fix the script if it's not working + # values: + # default: None + if "maxScriptCorrection" in self.config_obj: + self.max_script_correction = int(self.config_obj["maxScriptCorrection"]) + else: + self.max_script_correction = None + user_config = UserConfig() diff --git a/partcad/src/partcad/wrappers/wrapper_build123d.py b/partcad/src/partcad/wrappers/wrapper_build123d.py index 6769e406..a4dda0ca 100644 --- a/partcad/src/partcad/wrappers/wrapper_build123d.py +++ b/partcad/src/partcad/wrappers/wrapper_build123d.py @@ -74,8 +74,7 @@ def process(path, request): build_result = script_object.build(build_parameters=build_parameters) if not build_result.success: - sys.stderr.write("Exception: ") - sys.stderr.write(str(build_result.exception)) + wrapper_common.handle_exception(build_result.exception, path) results = list() # TODO(clairbee): make it recursive to handle nested lists and unify handling of items on different levels of nesting diff --git a/partcad/src/partcad/wrappers/wrapper_cadquery.py b/partcad/src/partcad/wrappers/wrapper_cadquery.py index 11ce46bc..eff5671e 100644 --- a/partcad/src/partcad/wrappers/wrapper_cadquery.py +++ b/partcad/src/partcad/wrappers/wrapper_cadquery.py @@ -60,8 +60,7 @@ def process(path, request): build_result = script_object.build(build_parameters=build_parameters) if not build_result.success: - sys.stderr.write("Exception: ") - sys.stderr.write(str(build_result.exception)) + wrapper_common.handle_exception(build_result.exception, path) shapes = [] for result in build_result.results: diff --git a/partcad/src/partcad/wrappers/wrapper_common.py b/partcad/src/partcad/wrappers/wrapper_common.py index d42983b6..b7d09e1f 100644 --- a/partcad/src/partcad/wrappers/wrapper_common.py +++ b/partcad/src/partcad/wrappers/wrapper_common.py @@ -51,3 
+51,22 @@ def handle_output(model): picklestring = pickle.dumps(model) response = base64.b64encode(picklestring) print(response.decode()) + + +def handle_exception(exc, cqscript=None): + sys.stderr.write("Error: [") + sys.stderr.write(str(exc).strip()) + sys.stderr.write("] on the line: [") + tb = exc.__traceback__ + # Switch to the traceback object that contains the script line number + tb = tb.tb_next + # Get the filename + fname = tb.tb_frame.f_code.co_filename + if cqscript is not None and fname == "": + fname = cqscript + + # Get the line contents + with open(fname, "r") as fp: + line = fp.read().split("\n")[tb.tb_lineno - 1] + sys.stderr.write(line.strip()) + sys.stderr.write("]\n") diff --git a/partcad/src/partcad/wrappers/wrapper_render_obj.py b/partcad/src/partcad/wrappers/wrapper_render_obj.py index 99ac618b..4ee4d066 100644 --- a/partcad/src/partcad/wrappers/wrapper_render_obj.py +++ b/partcad/src/partcad/wrappers/wrapper_render_obj.py @@ -51,6 +51,7 @@ def process(path, request): "exception": None, } except Exception as e: + wrapper_common.handle_exception(e) return { "success": False, "exception": e, diff --git a/partcad/src/partcad/wrappers/wrapper_render_svg.py b/partcad/src/partcad/wrappers/wrapper_render_svg.py index 35c91eb7..90d12979 100644 --- a/partcad/src/partcad/wrappers/wrapper_render_svg.py +++ b/partcad/src/partcad/wrappers/wrapper_render_svg.py @@ -67,6 +67,7 @@ def process(path, request): "exception": None, } except Exception as e: + wrapper_common.handle_exception(e) return { "success": False, "exception": str(e.with_traceback(None)),