Skip to content

Commit

Permalink
eval builtin function types
Browse files Browse the repository at this point in the history
  • Loading branch information
k2d222 committed Feb 14, 2025
1 parent 6aa7350 commit 029d267
Show file tree
Hide file tree
Showing 10 changed files with 644 additions and 69 deletions.
2 changes: 1 addition & 1 deletion crates/wesl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ attributes = ["wgsl-parse/attributes"]
condcomp = ["wgsl-parse/condcomp", "attributes"]
eval = []
generics = ["wgsl-parse/generics", "attributes"]
serde = ["dep:serde"]
serde = ["wgsl-parse/serde", "dep:serde"]
package = ["dep:proc-macro2", "dep:quote"]
635 changes: 587 additions & 48 deletions crates/wesl/src/eval/builtin.rs

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions crates/wesl/src/eval/constant.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Scope, SyntaxUtil};
use super::{is_constructor_fn, Scope, SyntaxUtil};
use itertools::Itertools;
use wgsl_parse::{span::Spanned, syntax::*};

Expand Down Expand Up @@ -261,7 +261,7 @@ impl IsConst for FunctionCall {
// TODO: this is not optimal as it will be recomputed for the same functions.
is_function_const(decl, wesl)
} else {
false
is_constructor_fn(&fn_name)
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions crates/wesl/src/eval/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ pub enum EvalError {
NotType(String),
#[error("unknown type or variable `{0}`")]
UnknownType(String),
#[error("unknown struct `{0}`")]
UnknownStruct(String),
#[error("declaration `{0}` is not accessible at {} time", match .1 {
EvalStage::Const => "shader-module-creation",
EvalStage::Override => "pipeline-creation",
Expand Down
11 changes: 7 additions & 4 deletions crates/wesl/src/eval/eval.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::iter::zip;

use super::{
call_builtin, compound_exec_no_scope, with_scope, Context, Convert, EvalError, EvalStage,
EvalTy, Flow, Instance, LiteralInstance, PtrInstance, RefInstance, StructInstance, SyntaxUtil,
Ty, Type, VecInstance, ATTR_INTRINSIC,
call_builtin, compound_exec_no_scope, is_builtin_fn, with_scope, Context, Convert, EvalError,
EvalStage, EvalTy, Flow, Instance, LiteralInstance, PtrInstance, RefInstance, StructInstance,
SyntaxUtil, Ty, Type, VecInstance, ATTR_INTRINSIC,
};

use half::f16;
Expand Down Expand Up @@ -295,10 +295,11 @@ impl Eval for FunctionCall {
}
// function call
else if let Some(decl) = ctx.source.decl_function(&fn_name) {
if !decl.attributes.contains(&Attribute::Const) && ctx.stage == EvalStage::Const {
if ctx.stage == EvalStage::Const && !decl.attributes.contains(&Attribute::Const) {
return Err(E::NotConst(decl.ident.to_string()));
}

// TODO: this should no longer happen, I deprecated PRELUDE while we stabilize generics and overloads.
if decl.body.attributes.contains(&ATTR_INTRINSIC) {
return call_builtin(&ty, args, ctx);
}
Expand Down Expand Up @@ -356,6 +357,8 @@ impl Eval for FunctionCall {
.inspect_err(|_| ctx.set_err_decl_ctx(decl.ident.to_string())),
};
inst
} else if is_builtin_fn(&fn_name) {
call_builtin(&ty, args, ctx)
}
// not struct constructor and not function
else {
Expand Down
8 changes: 5 additions & 3 deletions crates/wesl/src/eval/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
use wesl_macros::query_mut;
use wgsl_parse::{span::Spanned, syntax::*};

use super::{to_expr::ToExpr, EvalTy, SyntaxUtil, EXPR_FALSE, EXPR_TRUE};
use super::{is_constructor_fn, to_expr::ToExpr, EvalTy, SyntaxUtil, EXPR_FALSE, EXPR_TRUE};

type E = EvalError;

Expand Down Expand Up @@ -318,13 +318,15 @@ impl Lower for Statement {
Statement::Return(stmt) => stmt.lower(ctx)?,
Statement::Discard(_) => (),
Statement::FunctionCall(stmt) => {
let decl = ctx.source.decl_function(&*stmt.call.ty.ident.name());
let decl = ctx.source.decl_function(&stmt.call.ty.ident.name());
if let Some(decl) = decl {
if decl.attributes.contains(&Attribute::Const) {
*self = Statement::Void; // a void const function does nothing
*self = Statement::Void; // a const function has no side-effects
} else {
stmt.lower(ctx)?
}
} else if is_constructor_fn(&stmt.call.ty.ident.name()) {
*self = Statement::Void; // a const function has no side-effects
} else {
stmt.lower(ctx)?
}
Expand Down
15 changes: 14 additions & 1 deletion crates/wesl/src/eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ impl<'s> Context<'s> {

pub trait SyntaxUtil {
/// find a global declaration by name.
fn user_decl(&self, name: &str) -> Option<&GlobalDeclaration>;

/// find a global declaration by name, including built-in ones (see `prelude.wgsl`).
fn decl(&self, name: &str) -> Option<&GlobalDeclaration>;

/// find a variable/value declaration by name.
Expand All @@ -238,11 +241,21 @@ pub trait SyntaxUtil {
}

impl SyntaxUtil for TranslationUnit {
fn user_decl(&self, name: &str) -> Option<&GlobalDeclaration> {
// note: declarations in PRELUDE can be shadowed by user-defined declarations.
self.global_declarations.iter().find(|d| match d {
GlobalDeclaration::Declaration(d) => &*d.ident.name() == name,
GlobalDeclaration::TypeAlias(d) => &*d.ident.name() == name,
GlobalDeclaration::Struct(d) => &*d.ident.name() == name,
GlobalDeclaration::Function(d) => &*d.ident.name() == name,
_ => false,
})
}
fn decl(&self, name: &str) -> Option<&GlobalDeclaration> {
// note: declarations in PRELUDE can be shadowed by user-defined declarations.
self.global_declarations
.iter()
.chain(PRELUDE.global_declarations.iter())
// .chain(PRELUDE.global_declarations.iter())
.find(|d| match d {
GlobalDeclaration::Declaration(d) => &*d.ident.name() == name,
GlobalDeclaration::TypeAlias(d) => &*d.ident.name() == name,
Expand Down
32 changes: 23 additions & 9 deletions crates/wesl/src/eval/ty.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::str::FromStr;

use super::{
ArrayInstance, ArrayTemplate, AtomicInstance, AtomicTemplate, Context, EvalError, Instance,
LiteralInstance, MatInstance, MatTemplate, PtrInstance, PtrTemplate, RefInstance,
StructInstance, SyntaxUtil, TextureTemplate, VecInstance, VecTemplate,
builtin_fn_type, is_builtin_fn, ArrayInstance, ArrayTemplate, AtomicInstance, AtomicTemplate,
Context, EvalError, Instance, LiteralInstance, MatInstance, MatTemplate, PtrInstance,
PtrTemplate, RefInstance, StructInstance, SyntaxUtil, TextureTemplate, VecInstance,
VecTemplate,
};

type E = EvalError;
Expand Down Expand Up @@ -475,7 +476,10 @@ impl EvalTy for Expression {
Expression::Parenthesized(expr) => expr.expression.eval_ty(ctx),
Expression::NamedComponent(expr) => match expr.base.eval_ty(ctx)? {
Type::Struct(name) => {
let decl = ctx.source.decl_struct(&name).unwrap();
let decl = ctx
.source
.decl_struct(&name)
.ok_or_else(|| E::UnknownStruct(name.clone()))?;
let mem = decl
.members
.iter()
Expand Down Expand Up @@ -601,11 +605,21 @@ impl EvalTy for Expression {
})
}
Expression::FunctionCall(call) => {
let decl = ctx.source.decl_function(&call.ty.ident.name()).unwrap();
decl.return_type
.as_ref()
.map(|ty| ty.eval_ty(ctx))
.unwrap_or(Ok(Type::Void))
if let Some(decl) = ctx.source.decl_function(&call.ty.ident.name()) {
decl.return_type
.as_ref()
.map(|ty| ty.eval_ty(ctx))
.unwrap_or(Ok(Type::Void))
} else if is_builtin_fn(&call.ty.ident.name()) {
let args = call
.arguments
.iter()
.map(|arg| arg.eval_ty(ctx))
.collect::<Result<Vec<_>, _>>()?;
builtin_fn_type(&call.ty, &args, ctx)
} else {
Err(E::UnknownFunction(call.ty.ident.to_string()))
}
}
Expression::TypeOrIdentifier(ty) => ty.eval_ty(ctx),
}
Expand Down
2 changes: 2 additions & 0 deletions samples/main.wgsl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
const values = array(abs(-340282350000000000000000000000000000000f), abs(-9657247500000000000000000000000000000f), abs(-271270990000000000000000000000000000f), abs(-7523531000000000000000000000000000f), abs(-205307640000000000000000000000000f), abs(-5484528300000000000000000000000f), abs(-150371800000000000000000000000f), abs(-4244365600000000000000000000f), abs(-118425380000000000000000000f), abs(-3256697700000000000000000f), abs(-87893830000000000000000f), abs(-2337089400000000000000f), abs(-66257688000000000000f), abs(-1858791600000000000f), abs(-51469707000000000f), abs(-1401630400000000f), abs(-37338512000000f), abs(-1032194500000f), abs(-29100593000f), abs(-810784600f), abs(-22255488f), abs(-599186.25f), abs(-16049.632f), abs(-454.53058f), abs(-12.734693f), abs(-0.3520408f), abs(-0.009566326f), abs(-0.00025410551f), abs(-0.0000070844376f), abs(-0.00000019949309f), abs(-0.000000005549922f), abs(-0.00000000015205265f));

alias int = i32;
alias uint = u32;
alias float = f32;
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ fn parse_binding(
})?;
Ok((
(b.group, b.binding),
RefInstance::from_instance(inst, storage, access),
RefInstance::new(inst, storage, access),
))
}

Expand Down

0 comments on commit 029d267

Please sign in to comment.