-
Notifications
You must be signed in to change notification settings - Fork 73
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pass immutable arg_id_to_dtype to InKernelCallables #906
Conversation
4e57b84
to
011b5ee
Compare
@inducer I can't say I'm a fan of this. Maybe a nicer workaround would be to have |
I'm torn. On the one hand, I agree it's a mess. On the other hand, I dislike all the implicit casting in constructor-like interfaces ( A few practical concerns next up:
|
Yeah, agreed. My main worry was that it's a pain to find all the places that should be wrapped, but it's slowly getting there by rummaging through the warnings.
Definitely fine with waiting for that to get in first.
Maybe it can be fixed on this side? I haven't looked too much into it.. |
I'm sure the fix would be straightforward, which I'm happy to do. |
It's in! |
I have a branch that I am currently testing where I think this is fixed. I am happy for you to merge this when ready and then I will update our fork and get my things merged. |
1bb84f8
to
8c7accf
Compare
I'm note quite sure where to use batch mutation stuff in here (if any?). Most places that cast to |
8c7accf
to
f169d82
Compare
loopy/library/function.py
Outdated
new_arg_id_to_dtype = dict(arg_id_to_dtype) | ||
for i in range(len(arg_id_to_dtype)): | ||
if i in arg_id_to_dtype and arg_id_to_dtype[i] is not None: | ||
new_arg_id_to_dtype[-i-1] = new_arg_id_to_dtype[i] | ||
|
||
return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype, | ||
return (self.copy(arg_id_to_dtype=constantdict(new_arg_id_to_dtype), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could probably be smth like:
new_arg_id_to_dtype = constantdict(arg_id_to_dtype).mutate()
for i in range(len(arg_id_to_dtype)):
if i in arg_id_to_dtype and arg_id_to_dtype[i] is not None:
new_arg_id_to_dtype[-i-1] = new_arg_id_to_dtype[i]
return (self.copy(arg_id_to_dtype=new_arg_id_to_dtype.finish()),
This avoids a second copy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I see! Thanks a lot for the example. I'll make those changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I got all the places where this made sense (at least in the parts that were already modified). Let me know if you spot anything else!
3321e22
to
395d313
Compare
395d313
to
b86f8a1
Compare
@inducer This should be ready for a look now. @connorjward It seems like the Firedrake failure was on this side.. I prematurely cast something to |
Thanks! In it goes. |
This uses
immutabledict
whenever possible. It's not exactly clear to me if some of these should be passed in as immutable higher up the stack, so someone more knowledgeable should have a serious look.Fixes #905.