Skip to content

Commit

Permalink
fix: fix stackoverflow on schema circle dep
Browse files Browse the repository at this point in the history
Signed-off-by: he1pa <[email protected]>
  • Loading branch information
He1pa committed Sep 6, 2024
1 parent bf8505d commit c4b9fbd
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 27 deletions.
16 changes: 16 additions & 0 deletions kclvm/sema/src/advanced_resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1526,4 +1526,20 @@ mod tests {
2
);
}

#[test]
fn test_schema_circle_dep() {
let sess = Arc::new(ParseSession::default());

let path = "src/advanced_resolver/test_data/circle_dep.k"
.to_string()
.replace("/", &std::path::MAIN_SEPARATOR.to_string());
let mut program = load_program(sess.clone(), &[&path], None, None)
.unwrap()
.program;
let mut gs = GlobalState::default();
Namer::find_symbols(&program, &mut gs);
let node_ty_map = resolver::resolve_program(&mut program).node_ty_map;
AdvancedResolver::resolve_program(&program, &mut gs, node_ty_map).unwrap();
}
}
8 changes: 8 additions & 0 deletions kclvm/sema/src/advanced_resolver/test_data/circle_dep.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
schema Name(Name):
name: str

schema A(B):
name: str

schema B(A):
name: str
92 changes: 65 additions & 27 deletions kclvm/sema/src/core/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1078,15 +1078,6 @@ impl Symbol for SchemaSymbol {
match self.attributes.get(name) {
Some(attribute) => Some(*attribute),
None => {
if let Some(parent_schema) = self.parent_schema {
if let Some(attribute) =
data.get_symbol(parent_schema)?
.get_attribute(name, data, module_info)
{
return Some(attribute);
}
}

if let Some(for_host) = self.for_host {
if let Some(attribute) =
data.get_symbol(for_host)?
Expand All @@ -1105,6 +1096,25 @@ impl Symbol for SchemaSymbol {
}
}

if let Some(_) = self.parent_schema {
let mut parents = vec![];
parents.push(self.id.unwrap());
self.get_parents(data, &mut parents);
if parents.len() > 1 {
for parent_schema in &parents[1..] {
if let Some(parent_schema) = data.get_schema_symbol(*parent_schema) {
let parent_attr = parent_schema.get_self_attr(data, module_info);
for attr in parent_attr {
if let Some(attribute) = data.get_symbol(attr) {
if attribute.get_name() == name {
return Some(attr);
}
}
}
}
}
}
}
None
}
}
Expand All @@ -1115,24 +1125,17 @@ impl Symbol for SchemaSymbol {
data: &Self::SymbolData,
module_info: Option<&ModuleInfo>,
) -> Vec<SymbolRef> {
let mut result = vec![];
for attribute in self.attributes.values() {
result.push(*attribute);
}
if let Some(parent_schema) = self.parent_schema {
if let Some(parent) = data.get_symbol(parent_schema) {
result.append(&mut parent.get_all_attributes(data, module_info))
}
}

if let Some(for_host) = self.for_host {
if let Some(for_host) = data.get_symbol(for_host) {
result.append(&mut for_host.get_all_attributes(data, module_info))
}
}
for mixin in self.mixins.iter() {
if let Some(mixin) = data.get_symbol(*mixin) {
result.append(&mut mixin.get_all_attributes(data, module_info))
let mut result = self.get_self_attr(data, module_info);
if let Some(_) = self.parent_schema {
let mut parents = vec![];
parents.push(self.id.unwrap());
self.get_parents(data, &mut parents);
if parents.len() > 1 {
for parent in &parents[1..] {
if let Some(schema_symbol) = data.get_schema_symbol(*parent) {
result.append(&mut schema_symbol.get_self_attr(data, module_info))
}
}
}
}
result
Expand Down Expand Up @@ -1233,6 +1236,41 @@ impl SchemaSymbol {
r#ref: HashSet::default(),
}
}

pub fn get_parents(&self, data: &SymbolData, parents: &mut Vec<SymbolRef>) {
if let Some(parent_schema_ref) = self.parent_schema {
if let Some(parent_schema) = data.get_schema_symbol(parent_schema_ref) {
// circular reference
if !parents.contains(&parent_schema_ref) {
parents.push(parent_schema_ref);
parent_schema.get_parents(data, parents);
}
}
}
}

pub fn get_self_attr(
&self,
data: &SymbolData,
module_info: Option<&ModuleInfo>,
) -> Vec<SymbolRef> {
let mut result = vec![];
for attribute in self.attributes.values() {
result.push(*attribute);
}

if let Some(for_host) = self.for_host {
if let Some(for_host) = data.get_symbol(for_host) {
result.append(&mut for_host.get_all_attributes(data, module_info))
}
}
for mixin in self.mixins.iter() {
if let Some(mixin) = data.get_symbol(*mixin) {
result.append(&mut mixin.get_all_attributes(data, module_info))
}
}
result
}
}

#[allow(unused)]
Expand Down

0 comments on commit c4b9fbd

Please sign in to comment.