It looks like in the first call you're allocating a very large array ( Perhaps you need to ensure that the very large buffer is deleted before attempting further computations?
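If a large array really were staying alive on the device between calls, one way to release it eagerly is sketched below. This is a minimal sketch that assumes the buffer is a `jax.Array` bound to a hypothetical name `big`, since the original array isn't shown. Note that `del big` only drops the Python reference, while `jax.Array.delete()` frees the device buffer right away.

```python
import jax.numpy as jnp

# Hypothetical large array; the array from the original question is not shown.
big = jnp.ones((1024, 1024), dtype=jnp.float32)

# Finish the computation that still needs `big` before freeing it.
total = big.sum().block_until_ready()

# Explicitly free the device buffer instead of waiting for garbage collection.
big.delete()
print(big.is_deleted())
print(float(total))
```

After `delete()`, any further use of `big` raises an error, so this only makes sense once nothing else needs the array.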
-
Ah, right, I missed the fact that this was jit-compiled. In that case there shouldn't be any intermediate variables floating around. To see what's going on, I generated the compiled HLO for your functions. Here's the flat version:

    print(tst_flat_jit.lower(10).compile().as_text())

And here's the slice version:

    print(tst_slice_jit.lower(10, 0).compile().as_text())
The key thing to notice is that in Does that make sense?
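The HLO inspection described above can be reproduced end to end with a small script. This is a sketch under stated assumptions: the bodies of `tst_flat_jit` and `tst_slice_jit` aren't shown in the thread, so `flat_sum` and `slice_sum` below are hypothetical stand-ins (a full reduction versus a reduction over one row), and the size and index arguments are assumed to be static so that `.lower(10)` and `.lower(10, 0)` can fix the array shapes at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for tst_flat_jit: reduce the whole (n, n) array.
@partial(jax.jit, static_argnums=0)
def flat_sum(n):
    x = jnp.ones((n, n), dtype=jnp.float32)
    return x.sum()


# Hypothetical stand-in for tst_slice_jit: reduce only row i.
@partial(jax.jit, static_argnums=(0, 1))
def slice_sum(n, i):
    x = jnp.ones((n, n), dtype=jnp.float32)
    return x[i].sum()


# Lower and compile ahead of time, then dump the optimized HLO as text.
flat_hlo = flat_sum.lower(10).compile().as_text()
slice_hlo = slice_sum.lower(10, 0).compile().as_text()

print(flat_hlo)
print(slice_hlo)
```

Comparing the two printed modules, for example the fusions and the shapes of the buffers they produce, is one way to see whether the slice version materializes a larger intermediate than the flat one.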
-
This may seem like a silly thing to do, but it's a fairly minimal example of something that is confusing me. This script works when reducing the exceedingly large array, but fails on the sub-arrays:
The output I get looks like this:
Surely doing the reduction over fewer elements should take less memory, not more (or so I would have thought).