Hello! I am attempting to solve a large NLP through Optimization.jl. The problem is quite sparse, so I am using sparsity detection. A part of my NLP is doing inference on an already trained Lux model, and when the overall function is traced, I get a variety of warnings about fallback matmul because of mismatched types (one side being floats, the other being the tracers).

Previously (I forgot I did this...) I overloaded my wrapper of the Lux model on the tracer types to sidestep the tracing. I am now trying to do this correctly using the overloads that the package provides. The examples seem to only touch on scalar inputs. Is there anything special/different that must be done for (say) a vector of tracer types as the input?

Thanks again,
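For concreteness, the workaround described above might look something like the following sketch. `ModelWrapper` and `output_dim` are hypothetical names, and `GradientTracer` is assumed to be SparseConnectivityTracer's global gradient tracer type; this is not the package's recommended approach, just an illustration of the described hack:

```julia
using SparseConnectivityTracer: GradientTracer  # tracer type name assumed

struct ModelWrapper{M,P,S}  # hypothetical wrapper around a trained Lux model
    model::M   # the Lux model
    ps::P      # trained parameters
    st::S      # model states
end

# Regular forward pass on numeric inputs.
(w::ModelWrapper)(x::AbstractVector{<:Real}) = first(w.model(x, w.ps, w.st))

# Workaround overload on tracer inputs: skip the Lux forward pass and
# conservatively mark every output as depending on every input.
function (w::ModelWrapper)(x::AbstractVector{T}) where {T<:GradientTracer}
    t = reduce(+, x)               # `+` on tracers unions their index sets
    return fill(t, output_dim(w))  # output_dim(w): hypothetical output length
end
```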
-
Hi Andrew, thanks for the feedback!
Indeed, our focus is on scalar overloads. However, for global sparsity detection (TracerSparsityDetector) we also provide some array-level overloads. These two files should provide you with examples.

A pattern that you will see repeated is that we take a union of the input tracers and return a filled array of the resulting tracer:

SparseConnectivityTracer.jl/src/overloads/utils.jl, lines 5 to 48 in e6d357b

You can see these helper functions and FillArrays being used in the following example:

SparseConnectivityTracer.jl/src/overloads/arrays.jl, lines 101 to 106 in e6d357b
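To illustrate that pattern, here is a minimal sketch (not the package's actual internals: `trace_matvec` is a hypothetical name, and the union is done via the scalar `+` overload rather than the package's helper functions):

```julia
using FillArrays: Fill

# Sketch of an array-level overload: union the tracers of all inputs into a
# single tracer t, then return a lazy Fill so every output entry shares t.
function trace_matvec(A::AbstractMatrix{T}, x::AbstractVector{T}) where {T}
    t = reduce(+, A) + reduce(+, x)  # scalar `+` on tracers unions index sets
    return Fill(t, size(A, 1))       # output has length size(A, 1)
end
```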
Could you open a separate issue for this? I think it would be desirable for us to support Lux.jl out of the box.
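For reference, vector inputs already work directly with the high-level global sparsity detection API; a small end-to-end example (the function `f` here is just illustrative):

```julia
using SparseConnectivityTracer

# Global Jacobian sparsity of a vector-to-vector function.
f(x) = [x[1] * x[2], x[3]^2, x[1] + x[3]]

# Returns a 3×3 Bool sparsity pattern: row i marks the inputs output i depends on.
pattern = jacobian_sparsity(f, rand(3), TracerSparsityDetector())
```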