Skip to content

Commit

Permalink
llm e2e rust+pax
Browse files Browse the repository at this point in the history
  • Loading branch information
warfaj committed Oct 31, 2024
1 parent 84713be commit d7c469f
Show file tree
Hide file tree
Showing 17 changed files with 290 additions and 377 deletions.
2 changes: 1 addition & 1 deletion examples/src/fireworks/pax
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ set -e
current_dir=$(pwd)
pushd ../../../pax-cli
cargo build
PAX_WORKSPACE_ROOT=.. ../target/debug/pax-cli "$@" --path="$current_dir" --libdev --verbose
PUB_PAX_SERVER=http://localhost:8090 PAX_WORKSPACE_ROOT=.. ../target/debug/pax-cli "$@" --path="$current_dir" --libdev --verbose
popd
14 changes: 12 additions & 2 deletions examples/src/fireworks/src/fireworks.pax
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
<Group @wheel=self.handle_wheel>
<Rectangle width=22.20% height=14.21% x=24.31% y=37.05%/>
<Ellipse width=17.64% height=55.47% x=73.65% y=17.71% fill=rgba(0, 0, 255, 128) rotate={(ticks)deg}/>
<Group @wheel=self.handle_wheel y=30.80% x=41.04% height=43.00% width=53.91%>
<Text x=0% y=0% text="SCROLL" height=100% width=100% style={
font: Font::Web(
"Roboto",
Expand All @@ -12,10 +14,12 @@
align_multiline: TextAlignHorizontal::Center
underline: false
}/>
for i in 1..60 {
<Ellipse width=100px height=100px x=50% y=50% fill=rgba(255, 0, 0, 128)/>
for i in 1..200 {
<Rectangle class=rect width=300px height=300px/>
}
</Group>
<Ellipse width=50px height=50px x={bouncing_x} y=80% fill=rgba(0, 255, 0, 128)/>

@settings {
@tick: handle_tick
Expand All @@ -27,4 +31,10 @@
x: 50%
y: 50%
}
.particle {
fill: {hsl((j * 10.00 + ticks)deg, 85%, 55%)}
x: {(j * 2.00 + ticks) %% 100}
y: {(j * 3.00 + ticks) %% 100}
rotate: {(ticks * 2.00)deg}
}
}
4 changes: 2 additions & 2 deletions examples/src/fireworks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ impl Fireworks {

pub fn handle_tick(&mut self, _ctx: &NodeContext) {
let old_ticks = self.ticks.get();
self.ticks.set(old_ticks + 1);
self.ticks.set(old_ticks + 20); // Increment by 20 to make it 10x as fast
}
}
}
1 change: 1 addition & 0 deletions pax-chassis-web/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ log = "0.4.20"
console_error_panic_hook = { version = "0.1.6", optional = true }
js-sys = "0.3.63"
web-time = "1.1.0"
getrandom = { version = "0.2.15", features = ["js"] }

[dependencies.web-sys]
version = "0.3.10"
Expand Down
11 changes: 10 additions & 1 deletion pax-compiler/src/design_server/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use pax_designtime::messages::{
UpdateTemplateRequest,
};
use pax_manifest::{ComponentDefinition, ComponentTemplate, PaxManifest, TypeId};
use std::collections::HashMap;
use std::{collections::HashMap, path::Path};

pub mod socket_message_accumulator;

Expand Down Expand Up @@ -158,6 +158,15 @@ impl StreamHandler<Result<ws::Message, ws::ProtocolError>> for PrivilegedAgentWe
);
self.state.update_last_written_timestamp();
}
Ok(AgentMessage::WriteNewFilesRequest(files)) => {
for (filename, contents) in files {
let root = self.state.userland_project_root.lock().unwrap();
let path = root.join(&filename);
std::fs::write(&path, contents)
.unwrap_or_else(|e| eprintln!("Failed to write file: {}", e));
}
println!("Files written to disk. Time to recompile!");
}
Ok(AgentMessage::LoadFileToStaticDirRequest(load_info)) => {
let LoadFileToStaticDirRequest { name, data } = load_info;
println!(
Expand Down
33 changes: 14 additions & 19 deletions pax-compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,32 +130,27 @@ pub fn perform_build(ctx: &RunContext) -> eyre::Result<(PaxManifest, Option<Path
let mut manifests: Vec<PaxManifest> =
serde_json::from_str(&out).expect(&format!("Malformed JSON from parser: {}", &out));


// Simple starting convention: first manifest is userland, second manifest is designer; other schemas are undefined
let mut userland_manifest = manifests.remove(0);


// Populate project files
let mut project_files: Vec<(String, String)> = vec![];
if let Some(cargo_manifest_dir) = userland_manifest.cargo_manifest_dir.clone() {
let cargo_toml = fs::read_to_string(cargo_manifest_dir.clone() + "/Cargo.toml").unwrap();
project_files.push(("Cargo.toml".to_string(), cargo_toml));
let src_dir = cargo_manifest_dir.clone() + "/src";
let src_files = fs::read_dir(src_dir.clone()).unwrap();
for file in src_files {
let file = file.unwrap();
let file_name = file.file_name().into_string().unwrap();
let file_path = file.path();
let file_contents = fs::read_to_string(file_path).unwrap();
project_files.push(("src/".to_string() + &file_name, file_contents));
}

// i want to read add the cargo.toml from the cargo manifest dir as well as everything in the src/ directory to the project files
let cargo_manifest_dir = userland_manifest.cargo_manifest_dir.clone().unwrap();
let prefix = cargo_manifest_dir.rsplit_once("/").unwrap().1.to_string();

let cargo_toml = fs::read_to_string(cargo_manifest_dir.clone() + "/Cargo.toml").unwrap();
project_files.push((prefix.clone() + "/Cargo.toml", cargo_toml));

let src_dir = cargo_manifest_dir.clone() + "/src";
let src_files = fs::read_dir(src_dir.clone()).unwrap();
for file in src_files {
let file = file.unwrap();
let file_name = file.file_name().into_string().unwrap();
let file_path = file.path();
let file_contents = fs::read_to_string(file_path).unwrap();
project_files.push((prefix.clone() + "/src/" + &file_name, file_contents));
userland_manifest.project_files = project_files;
}

panic!("current running path: {:?}", project_files);

let mut merged_manifest = userland_manifest.clone();

//Hack: add a wrapper component so UniqueTemplateNodeIdentifier is a suitable uniqueid, even for root nodes
Expand Down
22 changes: 14 additions & 8 deletions pax-designer/src/console/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,24 @@ impl Console {
});
while let Some(message_type) = new_messages.pop() {
match message_type {
pax_designtime::orm::MessageType::LLMSuccess(component) => {
pax_designtime::orm::MessageType::LLMSuccess(components) => {
model::with_action_context(&ctx_p, |ctx| {
let t = ctx.transaction("llm update");
if let Err(e) = t.run(|| {
let mut dt = borrow_mut!(ctx.engine_context.designtime);
let orm = dt.get_orm_mut();
orm.replace_template(
component.type_id,
component.template.unwrap_or_default(),
component.settings.unwrap_or_default(),
)
.map_err(|e| anyhow!(e))?;
for component in components {
log::warn!(
"replacing templates: {:?}",
component.type_id
);
orm.replace_template(
component.type_id,
component.template.unwrap_or_default(),
component.settings.unwrap_or_default(),
)
.map_err(|e| anyhow!(e))?;
}
Ok(())
}) {
log::warn!("failed llm message component update {:?}", e);
Expand All @@ -94,7 +100,7 @@ impl Console {
pax_designtime::orm::MessageType::LLMPartial => {
let mut design_time = dt.borrow_mut();
let mut llm_message = design_time.get_llm_messages(current_id);
llm_message.reverse();
//llm_message.reverse();
for message in llm_message {
messages.push(Message {
message_type: MessageType::LLM,
Expand Down
6 changes: 6 additions & 0 deletions pax-designtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,12 @@ impl DesigntimeManager {
&mut self,
screenshot_map: Rc<RefCell<HashMap<u32, ScreenshotData>>>,
) -> anyhow::Result<()> {
if let Some(files) = self.orm.get_updated_project_files() {
self.priv_agent_connection
.borrow_mut()
.send_updated_files(files)?;
}

if let Some(mut llm_request) = self.enqueued_llm_request.take() {
let mut screenshot_map = screenshot_map.borrow_mut();
if let Some(screenshot) = screenshot_map.remove(&(llm_request.request_id as u32)) {
Expand Down
17 changes: 10 additions & 7 deletions pax-designtime/src/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::fmt::Debug;
pub enum AgentMessage {
ProjectFileChangedNotification(FileChangedNotification),
ManifestSerializationRequest(ManifestSerializationRequest),
WriteNewFilesRequest(Vec<(String, String)>),
// Request to retrieve the manifest from the design server
// sent from designtime to design-server
LoadManifestRequest,
Expand Down Expand Up @@ -72,11 +73,17 @@ impl LLMPartialResponse {
}
}

#[derive(Deserialize, Serialize)]
pub enum ChangeType {
PaxOnly(Vec<ComponentDefinition>),
FullReload(Vec<(String, String)>),
}

#[derive(Deserialize, Serialize)]
pub struct LLMFinalResponse {
pub request_id: u64,
pub message: String,
pub component_definition: ComponentDefinition,
pub changes: ChangeType,
}

impl Debug for LLMFinalResponse {
Expand All @@ -89,15 +96,11 @@ impl Debug for LLMFinalResponse {
}

impl LLMFinalResponse {
pub fn new(
request_id: u64,
message: String,
component_definition: ComponentDefinition,
) -> Self {
pub fn new(request_id: u64, message: String, changes: ChangeType) -> Self {
Self {
request_id,
message,
component_definition,
changes,
}
}
}
Expand Down
20 changes: 15 additions & 5 deletions pax-designtime/src/orm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ pub struct PaxManifestORM {
pub llm_messages: HashMap<u64, Vec<String>>,
pub new_message: Property<Vec<MessageType>>,
pub last_serialized_version: HashMap<TypeId, ComponentDefinition>,
pub updated_project_files: Option<Vec<(String, String)>>,
}

impl PaxManifestORM {
Expand All @@ -104,23 +105,32 @@ impl PaxManifestORM {
llm_messages: HashMap::new(),
new_message: Property::new(vec![]),
last_serialized_version,
updated_project_files: None,
}
}

pub fn set_updated_project_files(&mut self, files: Vec<(String, String)>) {
self.updated_project_files = Some(files);
}

pub fn get_updated_project_files(&mut self) -> Option<Vec<(String, String)>> {
self.updated_project_files.take()
}

pub fn add_new_message(
&mut self,
request_id: u64,
message: String,
component: Option<ComponentDefinition>,
components: Vec<ComponentDefinition>,
) {
self.llm_messages
.entry(request_id)
.or_insert(Vec::new())
.push(message);
self.new_message.update(|msgs| {
msgs.push(match component {
Some(component) => MessageType::LLMSuccess(component),
None => MessageType::LLMPartial,
msgs.push(match components.is_empty() {
true => MessageType::LLMPartial,
false => MessageType::LLMSuccess(components),
})
});
}
Expand Down Expand Up @@ -687,7 +697,7 @@ pub struct SubTrees {
pub enum MessageType {
Serialization(String),
LLMPartial,
LLMSuccess(ComponentDefinition),
LLMSuccess(Vec<ComponentDefinition>),
}

impl Interpolatable for MessageType {}
44 changes: 37 additions & 7 deletions pax-designtime/src/privileged_agent.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{
messages::{
AgentMessage, ComponentSerializationRequest, LLMRequest, LoadFileToStaticDirRequest,
AgentMessage, ChangeType, ComponentSerializationRequest, LLMRequest,
LoadFileToStaticDirRequest,
},
orm::PaxManifestORM,
};
Expand Down Expand Up @@ -71,6 +72,18 @@ impl WebSocketConnection {
}
}

pub fn send_updated_files(&mut self, files: Vec<(String, String)>) -> Result<()> {
if self.alive {
let msg_bytes = rmp_serde::to_vec(&AgentMessage::WriteNewFilesRequest(files))?;
self.sender.send(ewebsock::WsMessage::Binary(msg_bytes));
Ok(())
} else {
Err(anyhow!(
"couldn't send updated files: connection to design-server was lost"
))
}
}

pub fn send_llm_request(&mut self, llm_request: LLMRequest) -> Result<()> {
if self.alive {
let msg_bytes = rmp_serde::to_vec(&AgentMessage::LLMRequest(llm_request))?;
Expand Down Expand Up @@ -124,15 +137,32 @@ impl WebSocketConnection {
.map_err(|e| anyhow!(e))?;
}
AgentMessage::LLMPartialResponse(partial) => {
manager.add_new_message(partial.request_id, partial.message, None);
}
AgentMessage::LLMFinalResponse(final_response) => {
manager.add_new_message(
final_response.request_id,
final_response.message,
Some(final_response.component_definition),
partial.request_id,
partial.message,
vec![],
);
}
AgentMessage::LLMFinalResponse(final_response) => {
if let ChangeType::PaxOnly(components) = final_response.changes {
log::warn!("got pax-only changes:");
log::warn!("length: {}", components.len());
manager.add_new_message(
final_response.request_id,
final_response.message,
components,
);
} else if let ChangeType::FullReload(project_files) =
final_response.changes
{
let updated_files: Vec<(String, String)> = project_files
.iter()
.filter(|(f, _)| f.ends_with(".rs") | f.ends_with(".pax"))
.cloned()
.collect();
manager.set_updated_project_files(updated_files);
}
}
_ => {}
}
}
Expand Down
1 change: 1 addition & 0 deletions pax-engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ wasm-bindgen = {version = "0.2.92", optional=true}
wasm-bindgen-futures = {version = "0.4.42", optional=true}
log = "0.4.20"


[features]
gpu = ["pax-chassis-web?/gpu"]
designtime = ["dep:pax-designtime", "pax-runtime/designtime", "pax-chassis-web?/designtime", "pax-chassis-macos?/designtime", "pax-chassis-ios?/designtime", "pax-macro/designtime"]
Expand Down
6 changes: 2 additions & 4 deletions pax-generation/generated_project/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
[package]
name = "generated_project"
version = "0.13.19"
version = "0.38.3"
edition = "2021"
default-run = "run"

[dependencies]
pax-kit = { version = "0.35.0", path="../../pax-kit", features = ["designer"] }
rand = { version = "0.8.5" }
getrandom = { version = "0.2.15", features = ["js"] }
pax-kit = { version = "0.38.3", path="../../pax-kit" }

[lib]
crate-type = ["cdylib", "rlib"]
Expand Down
Loading

0 comments on commit d7c469f

Please sign in to comment.