Skip to content

Commit

Permalink
support truly nested types in where input
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo committed Jan 14, 2025
1 parent 77168df commit f3e3f6c
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 123 deletions.
268 changes: 148 additions & 120 deletions crates/torii/graphql/src/object/inputs/where_input.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use std::str::FromStr;

use async_graphql::dynamic::{
Field, InputObject, InputValue, ResolverContext, TypeRef, ValueAccessor,
};
use async_graphql::dynamic::{Field, InputObject, InputValue, ResolverContext, TypeRef, ValueAccessor};
use async_graphql::{Error as GqlError, Name, Result};
use dojo_types::primitive::{Primitive, SqlType};
use strum::IntoEnumIterator;
Expand All @@ -16,96 +13,81 @@ use crate::types::TypeData;
pub struct WhereInputObject {
pub type_name: String,
pub type_mapping: TypeMapping,
pub nested_inputs: Vec<WhereInputObject>,
}

impl WhereInputObject {
fn build_mapping(prefix: &str, types: &TypeMapping) -> TypeMapping {
types
.iter()
.filter(|(_, type_data)| !type_data.is_list())
.flat_map(|(type_name, type_data)| {
let field_name = if prefix.is_empty() {
type_name.to_string()
} else {
format!("{}_{}", prefix.replace('.', "_"), type_name)
};

if type_data.type_ref() == TypeRef::named("Enum")
|| type_data.type_ref() == TypeRef::named("bool")
{
return vec![(Name::new(field_name), type_data.clone())];
}

// Handle nested types
if type_data.is_nested() {
if let TypeData::Nested((_, nested_types)) = type_data {
return nested_types
.iter()
.flat_map(|(nested_name, nested_type)| {
if !nested_type.is_nested()
|| nested_type.type_ref() == TypeRef::named("Enum")
{
let nested_field = format!("{}_{}", field_name, nested_name);
return Comparator::iter().fold(
vec![(Name::new(&nested_field), nested_type.clone())],
|mut acc, comparator| {
let name =
format!("{}{}", nested_field, comparator.as_ref());
match comparator {
Comparator::In | Comparator::NotIn => acc.push((
Name::new(name),
TypeData::List(Box::new(nested_type.clone())),
)),
_ => {
acc.push((
Name::new(name),
nested_type.clone(),
));
}
}
acc
},
);
}

if let TypeData::Nested((_, further_nested_types)) = nested_type {
let new_prefix = format!("{}_{}", field_name, nested_name);
return Self::build_mapping(&new_prefix, further_nested_types)
.into_iter()
.collect();
}

vec![]
})
.collect();
fn build_field_mapping(type_name: &str, type_data: &TypeData) -> Vec<(Name, TypeData)> {
if type_data.type_ref() == TypeRef::named("Enum")
|| type_data.type_ref() == TypeRef::named("bool") {
return vec![(Name::new(type_name), type_data.clone())];
}

Comparator::iter()
.fold(vec![(Name::new(type_name), type_data.clone())], |mut acc, comparator| {
let name = format!("{}{}", type_name, comparator.as_ref());
match comparator {
Comparator::In | Comparator::NotIn => acc.push((
Name::new(name),
TypeData::List(Box::new(type_data.clone())),
)),
_ => {
acc.push((Name::new(name), type_data.clone()));
}
}

// Handle regular fields with comparators
Comparator::iter().fold(
vec![(Name::new(&field_name), type_data.clone())],
|mut acc, comparator| {
let name = format!("{}{}", field_name, comparator.as_ref());
match comparator {
Comparator::In | Comparator::NotIn => acc.push((
Name::new(name),
TypeData::List(Box::new(type_data.clone())),
)),
_ => {
acc.push((Name::new(name), type_data.clone()));
}
}
acc
},
)
acc
})
.collect()
}

pub fn new(type_name: &str, object_types: &TypeMapping) -> Self {
let where_mapping = Self::build_mapping("", object_types);
let mut nested_inputs = Vec::new();
let mut where_mapping = TypeMapping::new();

for (field_name, type_data) in object_types {
if !type_data.is_list() {
match type_data {
TypeData::Nested((_, nested_types)) => {
// Create nested input object
let nested_input = WhereInputObject::new(
&format!("{}_{}", type_name, field_name),
nested_types,
);

// Add field for the nested input using TypeData::Nested
where_mapping.insert(
Name::new(field_name),
TypeData::Nested((
TypeRef::named(&nested_input.type_name),
nested_types.clone()
))
);
nested_inputs.push(nested_input);
}
_ => {
// Add regular field with comparators
for (name, mapped_type) in Self::build_field_mapping(field_name, type_data) {
where_mapping.insert(name, mapped_type);
}
}
}
}
}

Self {
type_name: format!("{}WhereInput", type_name),
type_mapping: where_mapping,
nested_inputs,
}
}
}

Self { type_name: format!("{}WhereInput", type_name), type_mapping: where_mapping }
impl WhereInputObject {
pub fn input_objects(&self) -> Vec<InputObject> {
let mut objects = vec![self.input_object()];
for nested in &self.nested_inputs {
objects.extend(nested.input_objects());
}
objects
}
}

Expand All @@ -125,59 +107,105 @@ impl InputObjectTrait for WhereInputObject {
}
}


pub fn where_argument(field: Field, type_name: &str) -> Field {
field.argument(InputValue::new("where", TypeRef::named(format!("{}WhereInput", type_name))))
}

fn parse_nested_where(
input_object: &ValueAccessor,
type_name: &str,
type_data: &TypeData,
) -> Result<Vec<Filter>> {
match type_data {
TypeData::Nested((_, nested_mapping)) => {
let nested_input = input_object.object()?;
nested_mapping
.iter()
.filter_map(|(field_name, field_type)| {
nested_input.get(field_name).map(|input| {
let nested_filters = parse_where_value(
input,
&format!("{}.{}", type_name, field_name),
field_type,
)?;
Ok(nested_filters)
})
})
.collect::<Result<Vec<_>>>()
.map(|filters| filters.into_iter().flatten().collect())
}
_ => Ok(vec![]),
}
}

fn parse_where_value(
input: ValueAccessor,
field_path: &str,
type_data: &TypeData,
) -> Result<Vec<Filter>> {
println!("Parsing where value for {}: {:?}", field_path, type_data);
match type_data {
TypeData::Simple(_) => {
if type_data.type_ref() == TypeRef::named("Enum") {
let value = input.string()?;
return Ok(vec![parse_filter(&Name::new(field_path), FilterValue::String(value.to_string()))]);
}

let primitive = Primitive::from_str(&type_data.type_ref().to_string())?;
let filter_value = match primitive.to_sql_type() {
SqlType::Integer => parse_integer(input, field_path, primitive)?,
SqlType::Text => parse_string(input, field_path, primitive)?,
};

Ok(vec![parse_filter(&Name::new(field_path), filter_value)])
}
TypeData::List(inner) => {
let list = input.list()?;
let values = list
.iter()
.map(|value| {
let primitive = Primitive::from_str(&inner.type_ref().to_string())?;
match primitive.to_sql_type() {
SqlType::Integer => parse_integer(value, field_path, primitive),
SqlType::Text => parse_string(value, field_path, primitive),
}
})
.collect::<Result<Vec<_>>>()?;

Ok(vec![parse_filter(&Name::new(field_path), FilterValue::List(values))])
}
TypeData::Nested(_) => {
println!("Processing nested type for {}", field_path);
parse_nested_where(&input, field_path, type_data)
}
}
}

pub fn parse_where_argument(
ctx: &ResolverContext<'_>,
where_mapping: &TypeMapping,
) -> Result<Option<Vec<Filter>>> {
println!("Parsing where argument");
ctx.args.get("where").map_or(Ok(None), |where_input| {
println!("Where input: {:?}", where_input.as_value());
let input_object = where_input.object()?;
println!("Input object: {:?}", input_object.as_index_map());
where_mapping
.iter()
.filter_map(|(type_name, type_data)| {
input_object.get(type_name).map(|input| match type_data {
TypeData::Simple(_) => {
if type_data.type_ref() == TypeRef::named("Enum") {
let value = input.string().unwrap();
return Ok(Some(parse_filter(
type_name,
FilterValue::String(value.to_string()),
)));
}

let primitive = Primitive::from_str(&type_data.type_ref().to_string())?;
let filter_value = match primitive.to_sql_type() {
SqlType::Integer => parse_integer(input, type_name, primitive)?,
SqlType::Text => parse_string(input, type_name, primitive)?,
};

Ok(Some(parse_filter(type_name, filter_value)))
}
TypeData::List(inner) => {
let list = input.list()?;
let values = list
.iter()
.map(|value| {
let primitive = Primitive::from_str(&inner.type_ref().to_string())?;
match primitive.to_sql_type() {
SqlType::Integer => parse_integer(value, type_name, primitive),
SqlType::Text => parse_string(value, type_name, primitive),
}
})
.collect::<Result<Vec<_>>>()?;

Ok(Some(parse_filter(type_name, FilterValue::List(values))))
}
_ => Err(GqlError::new("Nested types are not supported")),
.filter_map(|(field_name, type_data)| {
println!("Processing field: {} with type: {:?}", field_name, type_data);
input_object.get(field_name).map(|input| {
println!("Found input for field {}: {:?}", field_name, input.as_value());
parse_where_value(input, field_name, type_data)
})
})
.collect::<Result<Option<Vec<_>>>>()
.collect::<Result<Vec<_>>>()
.map(|filters| Some(filters.into_iter().flatten().collect()))
})
}


fn parse_integer(
input: ValueAccessor<'_>,
type_name: &str,
Expand Down
7 changes: 6 additions & 1 deletion crates/torii/graphql/src/object/model_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ impl BasicObject for ModelDataObject {

impl ResolvableObject for ModelDataObject {
fn input_objects(&self) -> Option<Vec<InputObject>> {
Some(vec![self.where_input.input_object(), self.order_input.input_object()])
let mut objects = vec![];
objects.extend(self.where_input.input_objects());
objects.push(self.order_input.input_object());
Some(objects)
}

fn enum_objects(&self) -> Option<Vec<Enum>> {
Expand All @@ -99,8 +102,10 @@ impl ResolvableObject for ModelDataObject {
let mut conn = ctx.data::<Pool<Sqlite>>()?.acquire().await?;
let order = parse_order_argument(&ctx);
let filters = parse_where_argument(&ctx, &where_mapping)?;
println!("Filters: {:?}", filters);
let connection = parse_connection_arguments(&ctx)?;


let total_count = count_rows(&mut conn, &table_name, &None, &filters).await?;
let (data, page_info) = fetch_multiple_rows(
&mut conn,
Expand Down
4 changes: 2 additions & 2 deletions crates/torii/graphql/src/query/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,13 @@ pub fn parse_filter(input: &Name, value: FilterValue) -> Filter {
for comparator in Comparator::iter() {
if let Some(field) = input.strip_suffix(comparator.as_ref()) {
return Filter {
field: field.to_string().replace('_', "."),
field: field.to_string(),
comparator: comparator.clone(),
value,
};
}
}

// If no suffix found assume equality comparison
Filter { field: input.to_string().replace('_', "."), comparator: Comparator::Eq, value }
Filter { field: input.to_string(), comparator: Comparator::Eq, value }
}

0 comments on commit f3e3f6c

Please sign in to comment.