¶ How batch size affects token cost and speed
Today I'm interviewing Reiner Pope, who is CEO of MatX, a new chip startup. Previously, he worked on TPU architecture and many other things at Google. This is a very different format from my usual interviews: it's going to be a blackboard lecture, and we're going to get up in a second. We in fact built this whole new studio specifically with this format in mind, so it's a pleasure to get to inaugurate it with you.
We're going to be talking about model architecture, ML infra, and many other things. The reason I think it's an important topic is that once you actually understand how training and inference work in a cluster, as we'll see, a lot of things start making sense: why AI is the way it is, why AI architectures are the way they are, why API prices are the way they are, and, fundamentally, why AI progress is the way it is. You need to understand the details to get there, and you need a blackboard to understand the details. So Reiner, thank you so much for doing this.
Very happy to be here.
Just a heads up: this is a lecture with graphs and equations and all that stuff, so if you can, I would really recommend watching it on a video platform like YouTube. Full disclosure, I am an angel investor in MatX, but that's not related to this podcast.
Reiner, maybe to kick us off, I'll ask this question. A few products, like Claude, Codex, and Cursor, offer something like a fast mode where, for 6x the price, they'll stream you tokens at 2.5x the speed. Mechanically, I'm curious what's going on here. One, why is it the case that you can pay more to get lower latency? Two, could you keep going? Could you pay 100x more and get much, much faster speeds?
And three, could you go the other way? Could you have something like a Claude Code slow mode where, if you are willing to wait for minutes on end, you get even cheaper prices? Maybe this will help motivate the kind of analysis you'll be doing through the lecture.
Great. To jump a little bit to the conclusion: the big effect is batch size, and what we're going to do now is quantify exactly what that looks like and what its implications are for latency and cost. There's going to be another effect, which you can call speculative decoding or multi-token prediction. We can maybe come back to that later, but the first thing we'll talk through is batch size.
So what I'd like to introduce are the two principles of analysis. First, we're going to look at a roofline analysis of how we run a transformer model on a cluster of chips. We'll take, let's say, a Blackwell NVL72, so a rack of 72 GPUs. The roofline analysis means we look at memory bandwidth and compute performance.
The other side of that is that we're going to look at just two simple factors of the model: the time to operate on the weights and the time to operate on the context, the KV cache. So let's jump in. What we're going to try to do is estimate the time it takes to run an inference of a certain shape.
We can't be exact here; we can't precisely predict the time. So instead we're going to approximate, saying that the time must be greater than or equal to a certain quantity. We're going to consider two different aspects: the time it takes to do the memory fetches, and the time it takes to do the compute.
And it'll turn out that this actually gives us very strong predictive power, even with this simple model. So, one by one: what is the time it takes to do the compute?
There are really two things I need to do in the compute. I need to multiply by all of the active parameters, and then I need to do some work on the attention. For multiplying by all the active parameters, I have a certain batch size that I'm running and a certain number of active parameters in my model, and I just divide their product by the compute throughput, which is the FLOPs of the chip. That's a hardware constant.
So this accounts for all of the compute time for all of the weight matrix multiplies. There's a little caveat: we've ignored the time to do any of the attention computation, but that will in general be quite small in comparison.
Maybe I'll just interrupt from time to time to ask some very naive questions or to clarify some basic points. Just for the audience: you're not serving one user at a time. The batch refers to the fact that you're serving many different users at the same time.
Yeah.
And that's the whole batch.
Yeah, so I can motivate the batch at least a little bit. We'll see exactly why batching is such a favorable optimization, but it will turn out that if you do not batch together many users, the cost and economics you get can be something like a thousand times worse than if you do batch many users together. We'll be able to see that quite explicitly.
And then the number of active parameters: if I look at, for example, a DeepSeek model, DeepSeek-V3 has about 37 billion active parameters and about 700 billion total parameters. We're focusing on just the ones that are active for a single token.
Okay, so we've modeled compute performance. I'm going to keep writing equals, but in all of these cases you can think of this time as being at least this much, and maybe there are some terms we ignored. On the memory side, what do we need to do? We need to fetch all of the weights, so there is some time to fetch the total number of parameters, not just the active parameters.
So there's the weight fetch time. In addition, there's a KV cache fetch time, and this one actually depends on batch size: for every element of the batch, we have to fetch an entire context length's worth of tokens, times some size per token in bytes.
Just to jump in, let's quickly explain what the KV cache is.
Yeah. When I do a forward pass... let me actually draw how autoregressive inference works. This is during decode. I have a bunch of tokens of text, and I'm drawing them as a tensor because ultimately the tokens are represented as a tensor: in one direction there's the embedding dimension, and in this direction I have the sequence length.
The work of running a decode is that I have to run each token through a whole bunch of matrix multiplies over a bunch of different layers, and in general I have to do that work over all of these tokens. But one step of decode is actually to produce just this one additional token up here.
So what I'm going to do there is run a full forward pass, multiplying by all of the weight matrices in the entire model. But then I've got this attention mechanism where this token is looking at all of the past tokens, like this. What is it looking at specifically? It is looking at some internal representation that the model has produced of those tokens, and we call that the KV cache.
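To make the KV cache concrete, here is a minimal single-head attention decode loop in plain Python. It is a toy sketch, not any real model: the two-dimensional embeddings, the identity projection weights, and the names `KVCache` and `decode_step` are hypothetical. What it shows is the point just made: each decode step computes one new key/value pair, appends it to the cache, and then reads every cached entry, so the data read per step grows linearly with context length.

```python
import math
from dataclasses import dataclass, field


@dataclass
class KVCache:
    """Per-sequence cache: one key vector and one value vector per past token."""
    keys: list = field(default_factory=list)
    values: list = field(default_factory=list)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(w, x):
    """w is a list of rows; returns the matrix-vector product w @ x."""
    return [dot(row, x) for row in w]


def decode_step(x, cache, wq, wk, wv):
    """One decode step for a new token embedding x (toy single-head attention)."""
    q, k, v = matvec(wq, x), matvec(wk, x), matvec(wv, x)
    # The new token's key and value are computed once and kept for all later steps.
    cache.keys.append(k)
    cache.values.append(v)
    # Attention reads every cached key and value: this is the KV fetch.
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, kk) * scale for kk in cache.keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    out = [sum(w * vv[i] for w, vv in zip(weights, cache.values)) for i in range(len(v))]
    return out, len(cache.keys)  # entries read this step: grows by one per token


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical projection weights
    cache = KVCache()
    for t, x in enumerate([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]):
        out, read = decode_step(x, cache, identity, identity, identity)
        print(f"step {t}: read {read} cached key/value entries")
```

A real model keeps one such cache per layer (and per head or head group), which is where the "size per token" in the memory term comes from.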
This process of a single token attending to the whole history of tokens is attention, and it is mostly dominated by memory fetches rather than matrix multiplies. So we've got the amount of memory we're fetching shown over here, and that is of course divided by the memory bandwidth to give the memory time.
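Written out, the roofline model on the board looks roughly like this. The symbol names are labels chosen here for illustration, not notation from the talk: B is the batch size, N_act and N_tot are the active and total parameter counts, b_w is bytes per weight, L is context length, s_kv is KV-cache bytes per token, F is the compute throughput (counted in multiplies per second), and BW is memory bandwidth in bytes per second. As discussed, these are lower bounds, and attention compute is ignored.

```latex
T \;\ge\; \max\!\left(T_{\text{compute}},\; T_{\text{mem}}\right), \qquad
T_{\text{compute}} = \frac{B \, N_{\text{act}}}{F}, \qquad
T_{\text{mem}} = \underbrace{\frac{N_{\text{tot}} \, b_w}{\mathrm{BW}}}_{\text{weight fetch}}
  + \underbrace{\frac{B \, L \, s_{\text{kv}}}{\mathrm{BW}}}_{\text{KV fetch}}
```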
So in fact these equations are enough for us to draw some plots. The things we'd like to look at are sensitivity to batch size and, which we'll draw separately, sensitivity to context length. We said the big effect you can get is a trade-off between latency and cost in batch size. So let's draw them out. I think there are really just two graphs we want to draw. First, batch size versus time.
When we look at the shape of this, we've got a maximum of a sum and another term. So let's look at these terms one by one: how the compute and memory times scale and how they show up. First, the compute time: this is purely linear in batch size, with no offset, so it's a line like this. This is T_compute.
Then on the memory side, we've got a portion that is just a constant, a base offset, which is the weight fetch. Finally, we have the KV fetch term, which is linear in batch size, so it looks like that. The overall time is the sum of these two memory terms, maxed with the compute term. Let's first draw the sum.
So the two memory times together end up as a sloped line with an offset, like this. The overall maximum, which I'll draw a little thicker here, is the maximum of these two curves. So what does this actually mean? This is a latency plot. As I grow my batch size, I initially get only a weak dependence on batch size, and so there's a lower bound on latency here.
So this already partially answers the question. We can talk about varying the hardware configuration, but for a given hardware configuration there is a lower bound on latency, which is simply the time to read all of my total parameters from memory into the chips. If I use all of my memory bandwidth, I can't do any better than that.
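A small numerical sketch of the latency curve just drawn. All numbers are illustrative assumptions, not measurements: parameter counts in the ballpark of the DeepSeek-V3 figures above (37B active, about 700B total) stored at half a byte per weight (FP4), an assumed 2,048-token context with an assumed 10 KB of KV cache per token, and a hypothetical device with 8 TB/s of bandwidth and 4.5e15 multiplies per second. Treating the whole model as if it ran on one device of that speed is a simplification; the shape of the curve, not the absolute numbers, is the point.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Hardware:
    mults_per_s: float   # compute throughput, multiplies per second (assumed)
    bytes_per_s: float   # memory bandwidth, bytes per second (assumed)


@dataclass(frozen=True)
class Model:
    n_active: float            # parameters used per token
    n_total: float             # parameters fetched every step
    bytes_per_weight: float    # 0.5 for FP4
    kv_bytes_per_token: float  # KV cache size per context token (assumed)


def step_time(batch: int, context: int, model: Model, hw: Hardware) -> float:
    """Roofline lower bound (seconds) on one decode step for `batch` sequences."""
    t_compute = batch * model.n_active / hw.mults_per_s
    t_weights = model.n_total * model.bytes_per_weight / hw.bytes_per_s
    t_kv = batch * context * model.kv_bytes_per_token / hw.bytes_per_s
    return max(t_compute, t_weights + t_kv)


# Illustrative, assumed numbers only.
HW = Hardware(mults_per_s=4.5e15, bytes_per_s=8e12)
MODEL = Model(n_active=37e9, n_total=700e9, bytes_per_weight=0.5, kv_bytes_per_token=10e3)

if __name__ == "__main__":
    floor = MODEL.n_total * MODEL.bytes_per_weight / HW.bytes_per_s
    print(f"latency lower bound (weight fetch): {floor * 1e3:.2f} ms")
    for b in (1, 64, 1024, 4096, 16384):
        print(f"batch {b:>6}: {step_time(b, 2048, MODEL, HW) * 1e3:.2f} ms")
```

At small batch sizes the step time barely moves off the weight-fetch floor; only at large batch sizes does the compute line take over.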
It seems like a lot depends on the way you've drawn the slopes for compute time and for how the KV grows, and what implication the KV has on memory time.
What if this were above or below?
Or is that necessarily the case? Because if this is always true, then as batch size grows, compute always dominates the KV, which suggests that if you have a big enough batch size, maybe memory is never an issue.
Yeah, this is really sensitive to the context length, so I think we should come back and explore it. As you vary the context length, the KV fetch time will go up and up, and that will cause a transition from compute limited to memory limited.
Is there something especially significant about the slope being exactly the slope of the compute time?
Yeah, whenever we have balance points, it kind of says you're getting it exactly right. For the particular context length where the slopes match, I am equally memory bound and compute bound, which is a really desirable place to be.
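The slope-matching condition can be written out under the same dense-attention assumption. Per extra sequence in the batch, the KV fetch adds L·s_kv/BW of memory time and the weight multiplies add N_act/F of compute time (symbols labeled here for illustration: L context length, s_kv KV bytes per token, BW memory bandwidth, N_act active parameters, F multiplies per second). Setting the two slopes equal gives the context length at which the lines run parallel; beyond it the memory line is steeper, and no batch size makes the step compute bound.

```latex
\frac{L^{*} \, s_{\text{kv}}}{\mathrm{BW}} = \frac{N_{\text{act}}}{F}
\quad\Longrightarrow\quad
L^{*} = \frac{N_{\text{act}} \, \mathrm{BW}}{F \, s_{\text{kv}}}
```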
But suppose the optimal is, say, a 100k context length, and you go to a 200k context length. Does your MFU go down to something like 50%? Does it have a humongous impact on MFU?
Yeah.
Just from being slightly outside the optimal context-length range, the Goldilocks zone?
That's right, that is true as modeled here. A key point is that I'm modeling the memory fetch as linear in context length. That actually depends on model architecture. It is true for all of the model architectures with dense attention. Sparse attention actually scales much better than that.
Got it. And is sparse attention what everybody uses in practice?
I'm pretty excited about sparse attention. It's hard to know what the labs are using, but DeepSeek has published a sparse attention mechanism. I'll just put in a plug for sparse attention: some of the DeepSeek papers on sparse attention end up putting a square root in this term. Okay, so far we've looked at latency. It's kind of hard to read off cost from this. So what does cost mean?
To run this inference, I'm going to use the GPU for a certain amount of time, like one millisecond or twenty milliseconds, and I have to pay the rental cost for that time, something like two dollars an hour per GPU.
So that's the cost of this inference, but how many tokens have I processed during that inference? That is the batch size. So what we actually want to plot is cost per token versus batch size, which is T over B versus batch size. We have to imagine dividing each of these three curves by B. The compute curve was linear, and dividing by B makes it a constant here. This is T_compute. The KV fetch was linear, and now it becomes a constant as well. The weight fetch was constant, and now that we've divided by B, it becomes this hyperbola.
So again, we compute the max of the sum. The sum of the KV fetch and the weight fetch shifts the hyperbola up, giving a higher hyperbola like this. Then we take the max with the compute. So this is the overall shape we care about, and again we see some limiting behavior.
Initially it starts very high at a batch size of one; it's as if it almost goes to infinity, because we've got all these weight fetches that are not amortized over a large batch. But as we increase the batch size, the weight fetch becomes amortized over so many batch elements that its cost grows very small, and eventually the compute time ends up driving the cost. So there is a lower bound on cost, which is this one here.
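Dividing the step time by the batch size gives the cost-per-token curve. This sketch uses assumed numbers throughout: the $2 per GPU-hour rental figure mentioned above, a hypothetical 64-GPU deployment whose aggregate compute and bandwidth are 64 times an illustrative per-GPU figure, and the same kind of assumed model and context as a roofline input. It shows the weight-fetch hyperbola shrinking until the flat per-token floor (compute or KV fetch, whichever is larger) takes over.

```python
GPU_DOLLARS_PER_HOUR = 2.0   # assumed rental price
N_GPUS = 64                  # assumed deployment size

# Assumed aggregate hardware: 64 x an illustrative per-GPU figure.
MULTS_PER_S = N_GPUS * 4.5e15
BYTES_PER_S = N_GPUS * 8e12

# Assumed model and workload.
N_ACTIVE, N_TOTAL, BYTES_PER_WEIGHT = 37e9, 700e9, 0.5
CONTEXT, KV_BYTES_PER_TOKEN = 2048, 10e3


def step_seconds(batch: int) -> float:
    """Roofline lower bound on one decode step for the whole deployment."""
    t_compute = batch * N_ACTIVE / MULTS_PER_S
    t_mem = (N_TOTAL * BYTES_PER_WEIGHT + batch * CONTEXT * KV_BYTES_PER_TOKEN) / BYTES_PER_S
    return max(t_compute, t_mem)


def dollars_per_million_tokens(batch: int) -> float:
    """Rental cost of one step, divided by the `batch` tokens it produces."""
    dollars_per_second = N_GPUS * GPU_DOLLARS_PER_HOUR / 3600.0
    return step_seconds(batch) * dollars_per_second / batch * 1e6


def cost_floor() -> float:
    """Limit as batch grows: per-token compute or per-token KV fetch, whichever dominates."""
    per_token = max(N_ACTIVE / MULTS_PER_S, CONTEXT * KV_BYTES_PER_TOKEN / BYTES_PER_S)
    return per_token * N_GPUS * GPU_DOLLARS_PER_HOUR / 3600.0 * 1e6


if __name__ == "__main__":
    for b in (1, 16, 256, 4096, 65536):
        print(f"batch {b:>6}: ${dollars_per_million_tokens(b):10.4f} per 1M tokens")
    print(f"floor:        ${cost_floor():10.4f} per 1M tokens")
```

The floor is why a "slow mode" cannot push prices down indefinitely: the per-token compute and per-token KV reads never amortize.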
So a Claude Code slow mode or Codex slow mode would just live on this line, and it wouldn't help much, because you're not able to amortize the KV values over a much bigger batch.
Yeah, yeah. They're unique per batch element. The compute is also unique per batch element. So the question is: what is the minimal work you have to do per batch element after amortizing everything else away?
So, this point where you are no longer memory-bandwidth bound: practically, how big a batch do you need? How big are the batches in practice for frontier models?
You can just solve for that, actually, and it's not even particularly sensitive to model architecture. So let's go ahead and do that. What we're asking is when the memory time is equal to the compute time. For now I'm going to discard the KV term, because we're focused on what the batch size is, and really the question is when the weights are amortized over the multiplies. So I'm going to compare the weight fetch time to the weight multiply time and disregard the KV fetch term, just to simplify the analysis so we can get a clean answer out. So we're going to equate these two times.
Writing that out, we get the number of total parameters divided by the memory bandwidth being equal to the batch size times the number of active parameters, divided by the compute performance. Looking at this, everything on the top is a model parameter and everything on the bottom is a hardware parameter. It turns out to be nice to rearrange them so that the hardware parameters are on one side. So this is equivalent to the compute performance divided by the memory bandwidth being equal to the batch size times the number of active parameters, divided by the number of total parameters.
So this is a hardware parameter, and it actually ends up being a dimensionless constant if you look at it in terms of FLOPs. What are the dimensions of this? The top is multiplies per second and the bottom is bytes per second, so that's not quite dimensionless. But what you do is say, let's say I'm doing FP4: I take how many FP4 multiplies per second I can do, times the fact that each FP4 value is half a byte, and so I can make this dimensionless. On most GPUs this ends up being around 300.
Somewhere around three hundred.
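The balance-point algebra from the last few lines, written out with the KV term dropped. Symbols are labels chosen here for illustration: N_tot and N_act are total and active parameters, b_w is bytes per weight (1/2 for FP4), BW is memory bandwidth in bytes per second, and F is multiplies per second. The first step multiplies both sides by F/N_tot; the second solves for B.

```latex
\frac{N_{\text{tot}} \, b_w}{\mathrm{BW}} = \frac{B \, N_{\text{act}}}{F}
\quad\Longleftrightarrow\quad
\underbrace{\frac{F \, b_w}{\mathrm{BW}}}_{\text{hardware},\ \approx 300}
  = B \, \frac{N_{\text{act}}}{N_{\text{tot}}}
\quad\Longleftrightarrow\quad
B = \frac{F \, b_w}{\mathrm{BW}} \cdot \frac{N_{\text{tot}}}{N_{\text{act}}}
```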
And sorry, has that ratio changed over time as we've gone from generation to generation, where the FLOPs keep increasing?
That's a hardware parameter, so the question is to what extent the hardware has changed. From A100 to H100 to B100, the FLOPs have increased substantially and the memory bandwidth has also increased substantially, so the ratio has remained reasonably stable. And we can express this other term as well: it's a sparsity parameter. Let me phrase it slightly differently and solve for the batch size, just moving this back over to the other side.
We end up with: the batch size needs to be bigger than approximately 300 times the sparsity. So, for example, if I activate 32 out of 256 experts, the sparsity is 8. This actually gives you a ballpark that is remarkably accurate to practice. Generally people will go a little bit larger than this. They don't really want to be exactly at the balance point, because real-world efficiencies aren't as good as a roofline analysis would say. But take this and maybe double it or triple it.
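The same ballpark as a few lines of arithmetic. The per-chip figures are assumptions picked to be roughly Blackwell-class (about 4.5e15 FP4 multiplies per second and 8 TB/s of HBM bandwidth), not vendor specifications; the point is that the hardware ratio lands near the ~300 quoted above, and that the balance batch is that ratio times the sparsity. Using expert counts as a stand-in for parameter counts is a simplification that ignores dense parts of the model such as attention.

```python
def hardware_ratio(mults_per_s: float, bytes_per_s: float, bytes_per_weight: float) -> float:
    """Dimensionless: multiplies per second divided by weights fetched per second."""
    weights_per_s = bytes_per_s / bytes_per_weight
    return mults_per_s / weights_per_s


def balance_batch(ratio: float, n_total: float, n_active: float) -> float:
    """Batch size where weight-fetch time equals weight-multiply time (KV ignored)."""
    sparsity = n_total / n_active
    return ratio * sparsity


if __name__ == "__main__":
    r = hardware_ratio(mults_per_s=4.5e15, bytes_per_s=8e12, bytes_per_weight=0.5)
    print(f"hardware ratio ~ {r:.0f}")
    # 32 of 256 experts active -> sparsity 8, as in the example above.
    b = balance_batch(r, n_total=256, n_active=32)
    print(f"balance batch ~ {b:.0f}; with a 2-3x margin: {2 * b:.0f} to {3 * b:.0f}")
```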
Okay, so basically it's something like two to three thousand tokens per batch. But then what if you included the KV cache?
Yes.
The implication would be that the optimal batch size
should grow larger. Here we solved for the point where compute time equals memory time. If I add in something else that consumes memory bandwidth, then I have less bandwidth available for the weight loads, so I need to grow the batch size more to amortize them.
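Adding the KV term back in makes this quantitative. A minimal sketch, with all inputs assumed for illustration: solving compute time = weight-fetch time + KV-fetch time for the batch size gives B = T_weights / (per-sequence compute time minus per-sequence KV time). The required batch grows with context length, and once the per-sequence KV fetch is as slow as the per-sequence compute, there is no balance point at all, which is the memory-bound regime raised earlier.

```python
import math


def balance_batch_with_kv(n_total, n_active, bytes_per_weight, context,
                          kv_bytes_per_token, mults_per_s, bytes_per_s):
    """Batch at which compute time equals weight-fetch plus KV-fetch time.

    Returns math.inf when the KV fetch per sequence is at least as slow as the
    compute per sequence: memory time then grows at least as fast as compute
    time, so the step never becomes compute bound.
    """
    t_weights = n_total * bytes_per_weight / bytes_per_s      # constant offset
    compute_per_seq = n_active / mults_per_s                   # slope of compute time
    kv_per_seq = context * kv_bytes_per_token / bytes_per_s    # slope of KV time
    if kv_per_seq >= compute_per_seq:
        return math.inf
    return t_weights / (compute_per_seq - kv_per_seq)


if __name__ == "__main__":
    # Assumed, illustrative inputs.
    common = dict(n_total=700e9, n_active=37e9, bytes_per_weight=0.5,
                  kv_bytes_per_token=10e3, mults_per_s=4.5e15, bytes_per_s=8e12)
    for ctx in (0, 1024, 2048, 4096, 8192):
        print(f"context {ctx:>5}: balance batch {balance_batch_with_kv(context=ctx, **common):.0f}")
```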
This seems incredibly small. Like a batch this would be like less than one sequence, right?
Yeah, okay. Keep in mind that I'm talking about the number of sequences that I'm generating one more token for. So it's actually two thousand unique sequences in the batch.
Okay, so we're just talking about a single forward pass. So you think of the batch as the number of sequences rather than...
When I'm prepping for interviews, I often talk to experts in the field. For Reiner, I chatted with two of Jane Street's engineers, Clark and Axel. Clark, who works on low-latency trading systems, walked me through why Jane Street uses FPGAs to make sure they have predictable nanosecond latencies.
You can just do exactly what you need: touch 100 megabytes of SRAM and then get your response back in tens of nanoseconds, very easily. That's the important part.
He then went on to explain why CPUs just wouldn't work for this kind of thing. If you have a clock that's ticking every three nanoseconds, you actually have several bytes of information at a time to make your decision. That's as opposed to a CPU, where you'll just collect up a whole packet, let's say a 1500-byte packet, and then say, okay, this packet's ready, here you go, CPU, you can start thinking about it now. FPGAs allow you to react to the earliest part of the packet as it arrives, rather than having to wait for the full thing.
We also talked about liquid cooling, network design, and many other things. If you're interested in this stuff, Jane Street is hiring. You can check out their open roles at janestreet.com/dwarkesh, and if you want to watch the full prep conversation, we posted it there too.
If you've got a frontier model and you are actually doing inference, surely you must have more than two thousand concurrent users.
Yeah.
Is there any added latency from the fact that you need the whole batch to fill up? Or, if you have a reasonable number of users, is it so unlikely that it wouldn't fill up?
Yeah. The way I think about it is as a model of when the train departs. Let's say I've picked a batch size that I'm going to run at, maybe this batch size. By the way, this intersection point is the same intersection point as here.
So I pick this batch size, and I know it's going to take, for example, something like twenty milliseconds, which is a common place this ends up landing. So this is a timeline of what is running on the GPU: it's going to start a new batch every 20 milliseconds, regardless. This is 20 milliseconds, this is 40 milliseconds.
You can think of this as a schedule for the train. A new train departs every 20 milliseconds. Any passengers who are ready board the train. If the train is full, they wait for the next train. If the train is not full, it goes anyway. In terms of what that means for queuing latency, the worst case is that a request arrives just after a train departed. It has to wait for the next train, which is up to twenty milliseconds, and then it has to wait for that train to complete. So the worst-case latency is 40 milliseconds.
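A tiny simulation of the train-schedule picture, assuming a fixed 20 ms step, trains that leave on a fixed clock whether or not they are full, and a batch that never overflows. Each request waits for the next departure and then rides one full step, so its latency lands between one and two step times (the name `request_latency_ms` is just illustrative).

```python
import math
import random

STEP_MS = 20.0  # assumed decode step time: one "train" departs every STEP_MS


def request_latency_ms(arrival_ms: float, step_ms: float = STEP_MS) -> float:
    """Wait for the next departure, then ride one full step to completion."""
    next_departure = math.ceil(arrival_ms / step_ms) * step_ms
    return (next_departure - arrival_ms) + step_ms


if __name__ == "__main__":
    rng = random.Random(0)
    lat = [request_latency_ms(rng.uniform(0, 10_000)) for _ in range(100_000)]
    print(f"min {min(lat):.1f} ms, mean {sum(lat) / len(lat):.1f} ms, max {max(lat):.1f} ms")
```

With uniformly random arrivals the average lands midway, around one and a half step times; the 40 ms figure is the worst case.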
So how is the twenty milliseconds derived?
It's a rule of thumb, and where it comes from isn't fully explained yet. So far we've focused on memory bandwidth and compute time. When we look at memory, the other consideration is that we want to use all of the memory capacity we have, and generally we're going to use all of that capacity to store the weights or the KVs. So, in the time of doing a forward pass, maybe we want to read all of the memory capacity into the chip. That time is capacity divided by bandwidth, and it tends to be about twenty milliseconds across many different generations of HBM.
The units make sense: you have bytes divided by bytes per second.
Yeah, so for example, on the Rubin generation, I think it's something like 288 gigabytes divided by 20 terabytes per second, which comes out to about fifteen milliseconds.
Yeah, I understand the unit analysis. But what it's saying is that we can evacuate and replace the HBM in this amount of time. So we don't want to be in a situation where the HBM isn't big enough to hold everything we want to write to it, or a situation where our ability to move data in and out is too small by comparison.
Yeah, there are sort of two scenarios. Why don't we pick a latency bigger than fifteen milliseconds? If I think about what that means, it means I actually have time to read the HBM twice. By the way, most HBM accesses are reads, not writes. It's almost all reads, because the weight matrices are read-only and almost all of the KV cache accesses are reads. So let's say I run at 30 milliseconds: I can read all of HBM twice.
What's the point of that? I don't want to read the weight matrices twice, and I don't want to read the KVs twice.
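The capacity-over-bandwidth rule of thumb, and the "read it twice" observation, as arithmetic. The inputs are the figures quoted in the conversation (288 GB and 20 TB/s, in decimal units), taken as given; the helper names are illustrative.

```python
def hbm_sweep_ms(capacity_bytes: float, bandwidth_bytes_per_s: float) -> float:
    """Time to read the whole HBM once, in milliseconds."""
    return capacity_bytes / bandwidth_bytes_per_s * 1e3


def hbm_passes(step_ms: float, capacity_bytes: float, bandwidth_bytes_per_s: float) -> float:
    """How many times the full HBM contents could be read within one step."""
    return step_ms / hbm_sweep_ms(capacity_bytes, bandwidth_bytes_per_s)


if __name__ == "__main__":
    cap, bw = 288e9, 20e12
    print(f"one full sweep: {hbm_sweep_ms(cap, bw):.1f} ms")
    print(f"passes in a 30 ms step: {hbm_passes(30.0, cap, bw):.2f}")
```

A step much longer than one sweep buys time to re-read data that is only needed once per step, which is why the step time tends to settle near one sweep.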
Yeah, that makes a ton of sense. Okay, a couple of quick questions. One: if it is the case that the optimal batch size is something like two thousand, it's totally dependent on the sparsity, and not on the model size or anything?
I mean sparsity shows up in model size, but beyond that it only depends on sparsity, not on scale.
That's a very interesting result. One thing that might push toward centralization is that you would have these economies of scale in inference, from batching. But it seems like it's not that big a deal. I don't know, is two thousand users at the same time a lot? It doesn't seem like a lot.
We can do a bit of analysis on this. You can think of it in terms of number of users, but maybe a more productive way is in terms of tokens per second. So what does this batch size mean in terms of the system's tokens per second? Tokens per second is equal to the batch size divided by T: we run a batch's worth of tokens every T, where T is the 15 or 20 millisecond number. So this ends up being the batch size times about sixty, say 64 times B. With a batch of two thousand, that's around 2,000 times 64, so about 128K tokens per second.
This is in more digestible units. It's hard to reason about concurrent users, but what is the global traffic for a system? Sometimes the API providers brag in their announcements about how much traffic they have. The numbers I remember from some Gemini announcements last year were in the hundreds of millions of tokens per second worldwide. So this is about one thousandth of that.
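The tokens-per-second conversion as arithmetic, using the round numbers from the conversation: a batch of about 2,000 sequences and a step of about 1/64 s (roughly 15.6 ms). The "hundreds of millions of tokens per second" figure is only a remembered order of magnitude, so the comparison plugs in an assumed 128 million purely for illustration.

```python
def tokens_per_second(batch: int, step_s: float) -> float:
    """Each step emits one token for every sequence in the batch."""
    return batch / step_s


if __name__ == "__main__":
    tps = tokens_per_second(batch=2000, step_s=1 / 64)
    assumed_global_tps = 128e6  # assumed stand-in for "hundreds of millions"
    print(f"{tps:,.0f} tokens/s, {tps / assumed_global_tps:.4f} of the assumed global traffic")
```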
But Gemini is big, so one thousandth of Gemini is actually a lot. To be competitive at scale, you need to be able to do that.
Yeah.
Serve at least one thousandth of that traffic.
Yeah.
Interesting. Okay, so the more sparsity you have, the less compute you need. And it does seem that as batch sizes get bigger, compute ends up being the bottleneck, according to this analysis. So then the question is: how far can you take sparsity? That is, as the sparsity ratio increases and you have fewer and fewer active parameters relative to total parameters, how much does the performance of the model degrade? And is it degrading faster than you're saving compute by increasing the sparsity factor?
So you mean the quality of the model rather than its speed. Unfortunately, we're not able to answer that analytically; it's an empirical question of model quality. The best I can do is pull up a paper and answer it empirically.
Should we pull up the paper now?
So this paper is Unified Scaling Laws for Routed Language Models. It's a somewhat old paper by this stage, but one of the things they did is look at what the model quality impact is if I keep increasing sparsity. The answer is very sensitive to the actual choice of mixture-of-experts technique. Mixture of experts has been around for a really long time; I think it was maybe even back in 2017.
But the techniques have changed a lot. DeepSeek's mixture of experts was a big change in how it worked, and there are older papers such as GShard and Switch Transformer. So the actual empirical results are going to depend on all of that. But with one of the older techniques, shown here, you can see what happens if I hold the number of active parameters constant at a certain size and then increase the sparsity, which they call expert count here: the quality keeps increasing. If you imagine drawing a horizontal line across from the 1.3B dense model, you see that, for example, the 64-expert model with 370 million activated parameters is as good as a dense 1.3-billion-parameter model.
So in some sense these are actually not amazing returns: you need to increase total parameters a hundredfold to get the equivalent of ten times as many.
Yeah.
Yeah, actually even more so. It's a huge increase in parameter count for a modest increase in...
Yeah. So in this case, what is it? Four X.
Sixty four X for four X.
Yeah, so while it is true that you get this benefit of being able to economize on your compute time if you increase sparsity, and naively it would seem like a trade-off worth making, here you're decreasing this by 2x and having this go up by 8x.
So is that good or bad, actually? Even from a memory point of view, keep in mind that you are growing the portion of the memory fetches that is amortized by batch, so you just keep running a larger batch size. From the point of view of the analysis we've done here, this is a pure win. Keep doing it until you run out of available users, basically. So there's actually this equivalence: if I have a lot of users, I can go to a much sparser model. From that point of view, it's a reasonable trade-off. The other trade-off that shows up here is memory capacity. We've only reasoned about memory bandwidth here, but sparsity also consumes memory capacity.
Let me just make sure I understood. You're saying we want to spend less time computing, therefore we do more sparsity. To make that work, we need bigger batch sizes, which means we need more memory capacity.
Yeah.
Yeah.
Yeah, so maybe this would be a good point to talk about how a mixture-of-experts layer is typically laid out on a rack of GPUs.
¶ How MoE models are laid out across GPU racks
Yeah, where are we?
Sparse mixture of experts, and how we lay that out on GPUs. Let's zoom in on the mixture-of-experts layer first and draw what that looks like. We typically have some kind of router layer, which decides where we route the tokens to. So we have tokens coming in here, they go through a router layer, and then we have a bunch of different experts.
I'll draw a few more to line some up. Then the router decides which experts to route to, and it'll be a small fraction of them, maybe one in thirty-two. So maybe it decides to route to this one, this one, and this one. Each expert itself is a normal MLP: it has an up projection and a down projection with a nonlinearity in between.
Finally, we do the inverse operation. Where we were broadcasting things out here, we bring them back in and sum them up, like this. And then we have our residual connection: the token is also passed through here and gets added to the result of the MoE layer. So this is a normal MoE layer.
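A minimal pure-Python sketch of the layer just described: a router scores every expert, the top-k are chosen, each chosen expert is an MLP (up projection, nonlinearity, down projection), their outputs are combined with the router weights, and the residual is added. The tiny sizes, the ReLU, the softmax over only the selected experts, and all names here are assumptions for illustration; real MoE layers, DeepSeek's included, differ in routing and normalization details.

```python
import math
import random


def matvec(w, x):
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def relu(v):
    return [max(0.0, a) for a in v]


def make_expert(rng, d_model, d_hidden):
    """One expert is an ordinary MLP: up projection, nonlinearity, down projection."""
    up = [[rng.gauss(0, 0.1) for _ in range(d_model)] for _ in range(d_hidden)]
    down = [[rng.gauss(0, 0.1) for _ in range(d_hidden)] for _ in range(d_model)]
    return lambda x: matvec(down, relu(matvec(up, x)))


def moe_layer(x, router_w, experts, top_k):
    """Route token x to its top_k experts, combine their outputs, add the residual."""
    logits = matvec(router_w, x)  # one router score per expert
    chosen = sorted(range(len(experts)), key=lambda e: logits[e], reverse=True)[:top_k]
    m = max(logits[e] for e in chosen)
    gates = [math.exp(logits[e] - m) for e in chosen]  # softmax over chosen experts only
    total = sum(gates)
    out = list(x)  # residual connection: the token itself is passed through
    for g, e in zip(gates, chosen):
        y = experts[e](x)
        out = [o + (g / total) * yi for o, yi in zip(out, y)]
    return out, chosen


if __name__ == "__main__":
    rng = random.Random(0)
    d_model, d_hidden, n_experts, top_k = 4, 8, 8, 2
    experts = [make_expert(rng, d_model, d_hidden) for _ in range(n_experts)]
    router_w = [[rng.gauss(0, 1) for _ in range(d_model)] for _ in range(n_experts)]
    y, chosen = moe_layer([1.0, -0.5, 0.25, 2.0], router_w, experts, top_k)
    print("routed to experts", chosen, "output", [round(v, 3) for v in y])
```

Only the chosen experts' weights are touched for this token, which is exactly the active-versus-total parameter distinction from the roofline discussion.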
What I want to talk through is how this is mapped to a GPU rack and what that means for communication, because I think this will start to show some of the limits of how fast we can go. The standard practice here, and it is the best solution, is to use expert parallelism: different experts go on different GPUs. If we take something like a DeepSeek model, they have 256 experts.
Let's say we want to run that on a Blackwell rack, so there are 72 GPUs. We have a divisibility problem: 72 is not a power of two. So we'll simplify and say we're only going to use sixty-four of them and just ignore the other eight. It's not a big deal. And so we have four experts per GPU. Very simple.
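A minimal sketch of the expert-to-GPU assignment just described, assuming 256 routed experts, 64 of the rack's 72 GPUs, and contiguous blocks of experts per GPU. The function names are hypothetical, and contiguous placement is only one possible layout; real serving stacks may place experts differently.

```python
# Hypothetical sizes taken from the discussion: 256 routed experts, 64 GPUs used.
NUM_EXPERTS = 256
NUM_GPUS = 64


def experts_per_gpu(num_experts: int, num_gpus: int) -> int:
    """Experts each GPU holds under even expert parallelism."""
    if num_experts % num_gpus != 0:
        raise ValueError(f"{num_experts} experts do not divide evenly over {num_gpus} GPUs")
    return num_experts // num_gpus


def expert_to_gpu(expert_id: int, num_experts: int = NUM_EXPERTS, num_gpus: int = NUM_GPUS) -> int:
    """GPU that stores `expert_id` when experts are assigned in contiguous blocks."""
    return expert_id // experts_per_gpu(num_experts, num_gpus)


print(experts_per_gpu(NUM_EXPERTS, NUM_GPUS))  # 4
print([expert_to_gpu(e) for e in (0, 3, 4, 255)])  # [0, 0, 1, 63]
try:
    experts_per_gpu(NUM_EXPERTS, 72)  # using all 72 GPUs in the rack
except ValueError as err:
    print(err)
```

Trying all 72 GPUs raises the error, which is the divisibility problem that motivates dropping eight of them.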
For the sake of the diagram, I'll actually say we have two experts per GPU. These are the GPU boundaries; every pair of experts is on its own GPU. Then we can look at the communication cost. We had some tokens stored centrally here. They get routed to all of these experts, and so there is some communication cost paid here.
There's the same communication cost paid on the output. The hope is that this does not become communication-limited. Now, what is the traffic pattern here? The traffic pattern is that any GPU may be talking to any other GPU, depending on the decisions made by the model. So this is an all-to-all traffic pattern.
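To see why the pattern is all-to-all, here is a toy simulation. It assumes uniformly random top-k routing as a stand-in for a learned router (a real router is neither uniform nor input-independent), 8 experts per token, four experts per GPU in contiguous blocks, and a router copy on every GPU. For simplicity it counts one send per (token, expert) pair, even when two chosen experts share a GPU.

```python
import random
from typing import List


def traffic_matrix(num_gpus: int, num_experts: int, tokens_per_gpu: int,
                   top_k: int, seed: int = 0) -> List[List[int]]:
    """sends[src][dst] = number of (token, expert) sends from GPU src to GPU dst.

    Toy model: each GPU routes its own tokens; each token picks top_k distinct
    experts uniformly at random; experts sit in contiguous blocks per GPU.
    """
    rng = random.Random(seed)
    experts_per_gpu = num_experts // num_gpus
    sends = [[0] * num_gpus for _ in range(num_gpus)]
    for src in range(num_gpus):
        for _ in range(tokens_per_gpu):
            for expert in rng.sample(range(num_experts), top_k):
                sends[src][expert // experts_per_gpu] += 1
    return sends


m = traffic_matrix(num_gpus=64, num_experts=256, tokens_per_gpu=256, top_k=8)
print("GPUs that GPU 0 sends to:", sum(1 for count in m[0] if count > 0))
```

With a few hundred tokens per GPU, essentially every GPU ends up sending to every other GPU, which is why a fully connected scale-up network suits this layer.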
So when you say any GPU, in the sense that...? Yeah.
Yeah, the router. I drew this as one router. In reality you would actually have many copies of the router: as many routers as GPUs, in fact.
As many as the...
Yeah.
Incoming traffic.
Yeah. So these are sixty-four GPUs, and these are sixty-four GPUs. It's actually the same GPUs; we just draw them separately because they're serving different roles. So at this point, any GPU can be sending to any other GPU. And the way the Blackwell racks are configured is a perfect fit for the all-to-all communication pattern that the MoE actually wants to do.
However, maybe one rack is too slow and I want to use two racks. Then I have the challenge that there's a rack boundary drawn outside here, like this, and I no longer have all-to-all communication between all the GPUs in the two racks. So the rack-to-rack communication ends up being a substantial bottleneck.
So the fundamental thing here is that one rack actually bounds the size of an expert layer you can do. And this has been part of what's been driving the push towards larger and larger interconnect domains.
Yeah.
Yeah.
It may be worth explaining what exactly a rack is, the differences in bandwidth between racks and within a rack, and the all-to-all versus not-all-to-all nature of communication within versus outside a rack.
Yeah, and this is a place where it starts to be very different between, for example, NVIDIA and Google, and then others, including us. Generally, a rack is a physical structure, and it holds some number of GPUs or XPUs, typically about 64.
What constrains it to a certain size is power delivery, weight, and cooling ability. It ends up being about this size in many cases because of these physical constraints. Then when I deploy a data center, it may have thousands of these racks. I've got one of these tall racks with a bunch of GPUs in it, and then I put another rack next to it.
And you just place them next to each other?
Yeah, I just drop them in. In NVIDIA's case, the communication topology is that they put the GPUs on the outside of the rack and the switches on the inside. So there's a set of switches in here; these are the NVSwitches. Then they run a bunch of cables: every single GPU has cables going to the switches in the middle.
So every GPU goes to the switches in the middle, and the switches have connections to all the GPUs. So all of the GPUs can talk to all the other GPUs in just two hops: going to the switch, then going to the other GPU. Now, when I want to leave the rack, I end up going via a different path. The GPUs also have much slower connectivity for that, typically about eight times slower.
So the green that I drew here, in the GPU case, is NVLink. More generally it's called the scale-up network. You will typically also have a scale-out network, which allows you to connect to some data center switch.
And then all of the GPUs will have some connectivity up to a data center switch somewhere. This is the scale-out network, and it tends to be about eight times slower. So the challenge if you want to, for example, lay out a mixture-of-experts layer across two racks is that half of the GPUs here are going to want to talk to the GPUs here.
On average, when I look at where the tokens on these GPUs want to go, half of the tokens want to stay inside the rack. That's great; they can use the fast scale-up network. But half the tokens are going to want to leave the rack and go to the other rack, and that's not as good: they need to use a much slower network. So that becomes the bottleneck on the all-to-all pattern.
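A back-of-the-envelope sketch of the two-rack penalty. It assumes uniform routing (so a token's expert sits in the sender's rack with probability 1/num_racks), per-GPU scale-out bandwidth eight times lower than scale-up, and intra- and inter-rack sends that overlap perfectly so the slower link sets the time. The function names are hypothetical.

```python
def cross_rack_fraction(num_racks: int) -> float:
    """Share of routed traffic whose destination expert is in another rack,
    assuming experts are spread evenly over racks and routing is uniform."""
    return 1.0 - 1.0 / num_racks


def all_to_all_time(bytes_per_gpu: float, num_racks: int,
                    scale_up_bw: float, scale_out_bw: float) -> float:
    """Time for one GPU's dispatch, assuming intra- and inter-rack sends overlap
    perfectly, so whichever link finishes last sets the time."""
    cross = cross_rack_fraction(num_racks)
    intra_time = (1.0 - cross) * bytes_per_gpu / scale_up_bw
    inter_time = cross * bytes_per_gpu / scale_out_bw
    return max(intra_time, inter_time)


one_rack = all_to_all_time(1.0, num_racks=1, scale_up_bw=8.0, scale_out_bw=1.0)
two_racks = all_to_all_time(1.0, num_racks=2, scale_up_bw=8.0, scale_out_bw=1.0)
print(two_racks / one_rack)  # 4.0
```

Under these assumptions, splitting one MoE layer across two racks makes the all-to-all about 4x slower than keeping it inside one rack, even though only half the traffic leaves the rack.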
A different choice would be: why don't I have a much bigger switch here that connects everything and actually combines the two racks together? There are many ideas in this direction, but in general, the reason you have this hierarchy of switches rather than one big switch is to manage cabling congestion. You just need to run a large number of cables.
Is the question you just asked basically why the scale-up domain isn't bigger?
Yeah, exactly. Why not just have a million chips in a scale-up domain, or a thousand?
What has changed that allowed NVIDIA to go from Hopper, which was eight, to Blackwell, which is 72, and now Rubin will be, I don't know, is it 500? What has allowed that to happen?
From Hopper to Blackwell is mostly just the decision to switch from trays as the form factor (one of these is a tray) to racks as the form factor. That's a product decision; there wasn't a substantial technical barrier there. Going from sixty-four to five hundred or so, there's a bit of creative math there, but there is at least a genuine 4x increase, which is coming from a much more complicated and difficult rack design. So that is actually a new physical design to run more cables.
And the cable complication is just the cost of figuring out which cable hops to which, or which signal...?
Yeah, let's zoom in on this and look at the wire density. I'll draw this diagram once more so we have a cleaner, larger version to work with. Let's say I have some switches in the middle, and initially I'll start with just two trays of GPUs on each side. And let's say each tray wants to have two cables coming out of it.
So I physically run vertical cables that look like this, running onto the switches. Now if I want to double the number of GPUs in a rack, I need to run literally twice the density of cables, so I need to run these as well.
Extremely naive question. But if you look at a physical data center, it seems like there's a lot of space within a rack. I don't know, the cables are really big and...
Yeah. So there is space outside the rack. Inside the rack, as they become more optimized, these racks are very tight. There's connector density going from the tray into the rack's backplane, and the backplane itself has a really high density. There are other physical constraints too, including the bend radius of cables: you don't want to snap them, and so on.
Okay, so it's literally the physical space to put a cable. Yeah. That's what's constraining it. I had no idea. Interesting. That seems surprising: the rack is so big, and they just can't stuff more cables in there.
Yeah. Rack design is not my expertise, but when I talk to folks about the constraints they're up against, it's a combination of... well, what are the big physical things you're optimizing for?
The weight of the rack: it's actually really heavy, so you need enough metal for it not to sag and fall, but then you add more metal and it's heavier. And then power and cooling. All of those are competing, and modern racks are pushing all of them to very extreme physical limits.
Deep work is by its nature quite aversive, so even things which seem like work, like Slack and email, are easy ways to distract yourself. So I often wish that I could just turn the internet off. But if I'm prepping for an interview, even if I have the papers and books on hand, it's still super useful to be able to go back and forth with an LLM so I can break down concepts and research follow-ups.
Google's new Gemma 4 is the first open model that allows me to have this kind of fully disconnected focus. It's small enough to run on my laptop, but good enough to actually be useful. So to prep for this episode, I downloaded Reiner's scaling book and shut off the internet. I was able to have Gemma help me understand the material and answer my questions. If you want an LLM that you can run locally on your laptop or even your phone, you should check out Gemma 4.
When was GPT-4 released again? Was it 2022 or 2023? It was rumored to be over one trillion parameters. And it seems like only within the last six months have models been released with significantly more parameters than that model from three years ago. Yeah. When supposedly there should have been this
scaling in the meantime. Is the reason that we were just waiting for racks with enough memory to hold the KV cache for enough users, for a lot of sequences? Or for RL, if you're doing RL, a similar consideration of holding the KV cache for the whole batch of problems you're trying to solve. If you look at Hopper, you had eight Hoppers, and I think the
That's six hundred forty gigabytes, as of 2022. Yeah. With Blackwell, which was deployed, what, twenty twenty...
Very recently. Maybe last year.
Last year. You finally have a scale-up domain on the order of 10 to 20 terabytes, which is enough for a 5-trillion-parameter model plus KV cache.
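A quick capacity check for the numbers in this exchange. The per-GPU HBM figures are assumptions (80 GB for an H100-class Hopper GPU, roughly 180 GB for a Blackwell GPU; check current datasheets), as are one byte per parameter and the 5 TB KV-cache budget in the example. The function names are hypothetical.

```python
# Assumed per-GPU HBM capacities in GB (illustrative; check vendor datasheets).
HOPPER_HBM_GB = 80      # H100-class GPU
BLACKWELL_HBM_GB = 180  # rough Blackwell-class figure


def domain_capacity_tb(num_gpus: int, hbm_gb_per_gpu: float) -> float:
    """Total HBM in one scale-up domain, in terabytes (1 TB = 1000 GB)."""
    return num_gpus * hbm_gb_per_gpu / 1000


def fits(total_params: float, bytes_per_param: float, kv_bytes: float, capacity_tb: float) -> bool:
    """True if the weights plus a KV-cache budget fit in the given capacity."""
    needed_tb = (total_params * bytes_per_param + kv_bytes) / 1e12
    return needed_tb <= capacity_tb


hopper_8 = domain_capacity_tb(8, HOPPER_HBM_GB)          # 0.64 TB
blackwell_72 = domain_capacity_tb(72, BLACKWELL_HBM_GB)  # about 13 TB
# Hypothetical workload: 5T parameters at 1 byte each, plus a 5 TB KV-cache budget.
print(fits(5e12, 1, 5e12, hopper_8), fits(5e12, 1, 5e12, blackwell_72))  # False True
```

The 640 GB figure for eight Hoppers matches the transcript; the Blackwell total depends on which part and HBM configuration you assume.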
Yeah. Deploying larger scale-up domains is a huge unlock. I've drawn here the NVIDIA Blackwell deployment. The Google deployment has actually had very large scale-up domains from...
That also explains why Gemini seemed to be ahead. Gemini 2.5 was a success, and it just seems like Gemini has had successful pre-trains for longer than some of the other labs.
Not having been there at the time, I'm not sure how much is coming from successfully deploying higher sparsity ratios, which could be part of it. There's also a whole bunch of actual modeling things, like specifically how you do the mixture of experts. We've seen the
DeepSeek mixture of experts, where they said that activating more experts, but finer-grained experts, was a big innovation. I'm sure there are many other innovations on the model architecture as well as on the training data, and it's hard to disentangle all of them. But in terms of the limits of what you can do, the active parameters, as we saw, are limited by the compute cost, and the total parameters are limited by the scale-up size.
¶ How pipeline parallelism spreads model layers across racks
Is operating within a single scale-up domain a consideration specifically for forward or backward? Or specifically for prefill versus decode?
Or
Or is it preferred to always be within a scale-up domain, whatever kind of workload you have, whether you're doing a pre-training run, RL generation, or inference for users?
Yeah, really interesting. To answer that question we're going to need to talk about the communication patterns. We've talked about the mixture-of-experts communication pattern, which is this all-to-all. All-to-all very strongly favors full connectivity, which is what we've just shown here, and favors being within one rack.
There are other kinds of parallelism besides expert parallelism, which we just showed here. Another in the literature is tensor parallelism. With the trend towards smaller experts, this has become much less relevant, so we can ignore it. But the other two things that we have available are data parallelism and pipeline parallelism, and they can be a much better fit for using multiple racks.
So let's focus on pipeline parallelism specifically. This is one MoE layer, and I'm going to have a hundred more layers above it. I could decide at this point, for example, to move to a different rack. Now, is that going to become a communication bottleneck?
We can actually solve for when this becomes a communication bottleneck. But before we do that algebraically, let's visualize it and sketch the path. This is another MoE layer, and we're going to have another MoE layer here, and so on. So let's say I change rack here, and then some number of layers later I change rack here as well.
So the methodology we're going to use to determine whether we have a communication bottleneck at this point where we change rack is to compare the scale-out bandwidth requirements to the scale-up bandwidth requirements. And the hint is going to be that
there are a lot more sends here: we're sending many things here, whereas we're only sending one thing here, and we're also maybe doing it many times. That's what makes the difference.
Can I try to guess, just out of curiosity, to see if I'm actually understanding? It seems like you're sending batch size into the rack, but the communication within a rack is sort of batch size times number of GPUs.
Yeah, so number of activated GPUs, right? I don't send to this GPU at all. So there's an explosion from one to three times larger here in this diagram. Yeah. The key thing is that I didn't even need to send to this GPU at all, so that's a big saving. Okay, so we're going to talk through to what extent scale-up is a bottleneck relative to scale-out. We'll jump directly to the ratio of the time spent on scale-up over the time spent on scale-out; this is the quantity we're talking about. The first consideration is that scale-up is generally eight times faster than scale-out. So at baseline, if we were sending the same amount of data over both, we would have one over eight, which comes from bandwidth. Then we have some amount of expansion in how much data we're sending. If one token comes in here, it gets routed to some number of experts; in the DeepSeek case, eight routed experts. So this is the number of activated experts. And then it also...
The same thing applies on multiple different layers. Maybe I'm going to run two layers per stage, so there's also a multiplication by the number of layers.
And don't you need to multiply the whole thing by two, for sending out and bringing back in?
Yes, yes, there's a factor of two. Thank you. What we would like is for the scale-up time to be greater than or equal to the scale-out time, because scale-up bandwidth is the more important and precious resource. So we would like this ratio to be greater than or equal to one. And this really doesn't seem hard: there's just a factor of eight that we need to overcome, so we need the product of these three things to be at least eight.
Typically we have a fairly large number of activated experts; it could be eight by itself. And then we can increase the number of layers per stage a lot until we satisfy this. So what this ends up looking like is that I can have an entire pipeline of racks, where one rack does one layer, then I move on to the next rack and do another layer, then move on to the next rack and do another layer.
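The ratio just derived, written as a small function: time_scale_up / time_scale_out = (1/8) × 2 × n_activated × n_layers_per_stage, which we want to be at least one. This sketch assumes scale-up bandwidth is 8x scale-out, a factor of 2 for sending tokens out to the experts and bringing them back, and that each token crosses the rack boundary once per stage. The names are hypothetical.

```python
def scale_up_over_scale_out(n_activated: int, layers_per_stage: int,
                            bw_ratio: float = 8.0, round_trip: int = 2) -> float:
    """time_scale_up / time_scale_out for one pipeline stage.

    Scale-up traffic per token: round_trip * n_activated * layers_per_stage units,
    moved at bw_ratio times the scale-out bandwidth. Scale-out traffic: 1 unit
    per token at the stage boundary.
    """
    return round_trip * n_activated * layers_per_stage / bw_ratio


def min_layers_per_stage(n_activated: int, bw_ratio: float = 8.0) -> int:
    """Fewest layers per rack so that scale-out is not the bottleneck (ratio >= 1)."""
    if n_activated < 1:
        raise ValueError("need at least one activated expert")
    layers = 1
    while scale_up_over_scale_out(n_activated, layers, bw_ratio) < 1.0:
        layers += 1
    return layers


print(scale_up_over_scale_out(n_activated=8, layers_per_stage=1))  # 2.0
print(min_layers_per_stage(1), min_layers_per_stage(8))  # 4 1
```

With eight activated experts, even one layer per rack already clears the bar, which is why a pipeline of racks with one layer each is workable.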
It's interesting to me that the best parallelism strategy in practice ends up being one which physically resembles the actual architecture. It's not some galaxy-brain thing. It's like, oh, we have experts, we're going to put them on different GPUs. Oh, we have different layers, we're going to put them on different racks. I feel like that's interesting, that the physical and... Yeah, exactly. Yeah. I mean, it could have been something wackier with tensor parallelism and whatever.
Yeah. So I think a way to think of it, okay, the galaxy-brain way to think of it, is: what are all the different dimensions along which a model is scaled up? It is scaled up by layers, by the d_model dimension, by the d_ff dimension, and by the number of experts. Every single one of those numbers you can choose to cut along.
And if those numbers are big enough, it eventually becomes profitable to cut along there. We have selected two of them. The other two, given the way models are typically sized, are not profitable.
So there's a talk by Ilya where he says that, as we now know, pipelining is not wise. And Horace He gave my friends and me (I hate that it sounds like a Dr. Seuss quote) a lecture on these different kinds of parallelism, and he said the problem with pipeline parallelism is that, other than the bubbles, it creates architectural constraints. Yes. Kimi, for example, has these residuals where attention attends to the
A few back?
Yeah, layers a few back, and so it becomes hard to implement this way.
Yeah. And I guess we didn't fully articulate what the benefit is that we're getting from pipelining. These complexities are real; pipelining is a massive hassle. But it does give you some benefits, and then you can decide whether those benefits are worth the costs. It has some benefits in inference, and maybe bigger benefits in training. Take inference.
What are we saving on? Are we saving on memory time or compute time? Not really. We're just moving the memory time from one rack to a different rack; there's no actual benefit in runtime. What we are saving on is memory capacity: the amount of memory used per rack. If the memory capacity in a rack is a bottleneck, that's a constraint on how fast we can go, and pipelining allows us to massively reduce that bottleneck.
I guess there's the opposite connotation to this. Before this interview, I was chatting with Axel, who's a GPU performance engineer at Jane Street, and he was explaining that to do pipelining you have to use micro-batches rather than full batches. And if you use micro-batches, then you're by definition not able to amortize the weights
That's right.
across all the users or all the sequences. So the positive connotation is that you don't have to use as much memory. The negative connotation is that we can't amortize loading the weights across all those users.
Yeah, well, let's draw the pipeline bubble. So what is this micro-batching that shows up in pipeline parallelism? I'll focus on inference first; it's a slightly simpler problem. This axis is time, and this one is which rack we're on. Maybe I'll have four racks. So I've got an inference that steps through these four racks over time, like this. It runs at a certain batch size and steps through all the pipeline stages. Now, if we were to run inference number one only after that finishes, here, this is clearly a massive waste: three quarters of the time, each of the racks is doing nothing.
So we don't actually run inference one there. We run it as soon as we can, which is immediately after inference zero finishes on the first rack, like this.
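A toy rendering of the inference pipeline schedule being drawn, assuming four racks, four micro-batches, one time step per stage, and that a micro-batch starts its next decode step on rack 0 as soon as it leaves the last rack. Rows are time steps, columns are racks, and `.` marks an idle rack; the function name is hypothetical.

```python
from typing import List, Optional


def steady_schedule(num_stages: int, num_microbatches: int,
                    num_steps: int) -> List[List[Optional[int]]]:
    """grid[t][r] = micro-batch running on rack r at step t, or None if idle.

    Micro-batch m enters rack 0 at step m and is on rack (t - m) % num_stages
    at every later step t, looping back to rack 0 for its next decode step.
    """
    if not 1 <= num_microbatches <= num_stages:
        raise ValueError("toy model assumes 1 <= num_microbatches <= num_stages")
    grid = []
    for t in range(num_steps):
        row: List[Optional[int]] = []
        for r in range(num_stages):
            m = (t - r) % num_stages  # only candidate micro-batch for rack r at step t
            row.append(m if m < num_microbatches and t >= m else None)
        grid.append(row)
    return grid


for t, row in enumerate(steady_schedule(num_stages=4, num_microbatches=4, num_steps=8)):
    print(t, " ".join("." if m is None else str(m) for m in row))
```

After the first three warm-up steps, every rack is busy on every step, and micro-batch 0 re-enters rack 0 right after leaving rack 3.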
And then we keep going. If we hadn't filled this in, we would call this the pipeline bubble. When I've drawn it in this inference context, where we're only doing a forwards pass, it's obvious: why would you do this stupid thing? In a training context it's maybe less obvious, but in the inference context it's really natural to make this change.
Oh, interesting. So this is sort of obvious, but the difference between micro-batch and batch doesn't matter at all in inference, because you can just call it whatever you want. Yeah. It only matters in training, because there is an optimal batch size.
Yes.
And before you do a full backward step, you want to have accumulated all the sequences in that batch. And if you want to do pipelining and
Training.
in order to avoid that bubble, you need to do that.
Let's draw the training diagram with that. Yeah, let's do that. This is the inference diagram, and I'll call this four so we don't have the wrong thing showing up there. Now let's do the same thing for training. We've got a forwards pass, but at some stage we're going to have to transition to a backwards pass. So we'll do some number of batches in the forwards pass.
And then we're going to transition to the backwards pass for everyone, all in one go.
So the inference part is the same here, but then we do a hard stop at this point and transition everyone to the backwards pass, with similar numbering, like this.
It may be worth clarifying that the reason for that hard stop is that you want to do a whole batch at once for the backward step, and there is an optimal size for how big that batch should be.
Yeah, I mean...
Smaller is always better, actually, is one way to put it. From an ML convergence-rate perspective, smaller is always better, because you're getting the freshest information from gradient descent.
But from a total training time perspective?
From a total training time perspective, smaller is worse; it's worse from a systems perspective. So the optimum is the trade-off between those two. Yeah. So you pick a batch size, and for that batch size you do some amount forwards and then some amount backwards. You asked why there's even a hard stop. Because of the hard stop, pipeline parallelism leaves you with this idle time here, which is the bubble.
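A sketch of the training schedule with the hard stop, often called a GPipe-style schedule: all micro-batches go forward through the stages, then all go backward. It assumes every stage takes one step per micro-batch and counts idle rack-steps in the forward phase; the backward phase mirrors it, and scaling all backward steps by the same factor leaves the idle fraction unchanged. The result agrees with the closed form (P - 1) / (M + P - 1) for P stages and M micro-batches. Function names are hypothetical.

```python
from typing import List, Optional


def gpipe_forward_grid(num_stages: int, num_microbatches: int) -> List[List[Optional[int]]]:
    """grid[t][r] = micro-batch that stage r runs forward at step t, or None if idle."""
    steps = num_microbatches + num_stages - 1
    return [
        [t - r if 0 <= t - r < num_microbatches else None for r in range(num_stages)]
        for t in range(steps)
    ]


def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction of rack-steps in the forward phase; the backward phase mirrors it."""
    grid = gpipe_forward_grid(num_stages, num_microbatches)
    idle = sum(cell is None for row in grid for cell in row)
    return idle / (len(grid) * num_stages)


P = 4
for M in (1, 4, 16, 64):
    closed_form = (P - 1) / (M + P - 1)
    print(M, round(bubble_fraction(P, M), 3), round(closed_form, 3))
```

With a single micro-batch, three quarters of the rack-time is idle, matching the picture of one inference stepping through four racks; more micro-batches per batch shrink the bubble.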
There are many techniques in the literature for how to lay this out differently and avoid that. There are more complicated schemes, called things like zero bubble or one-forward-one-backward, which interleave the forwards and backwards passes in complicated ways. But... Right, right. More usefully, you can do the weight-gradient step, but you can also do the input gradient. Yeah. So in inference, actually, the effect of pipelining on anything you care about, like batch size or latency, is neutral. It doesn't improve it and it doesn't make it worse. Compare the latency of this inference pipelined versus all on one rack: if it were all on one rack, we would just slide all of the boxes down and still put them in a row, and the latency would be the same.
So pipelining is neither better nor worse for latency. But it does mean you use less memory capacity per rack, because instead of needing the whole model, you only need a quarter of the model in each.
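A sketch of the two claims here, assuming equal-sized layers and ignoring the cost of the hop between racks: splitting layers across racks leaves the per-step latency unchanged (every layer still runs once, in order), while dividing the weight bytes each rack must hold by the number of stages. Function names and the 60-layer, 1 TB example are hypothetical.

```python
from typing import List


def split_layers(num_layers: int, num_stages: int) -> List[List[int]]:
    """Assign contiguous, nearly equal groups of layer indices to pipeline stages (racks)."""
    base, extra = divmod(num_layers, num_stages)
    stages, start = [], 0
    for s in range(num_stages):
        size = base + (1 if s < extra else 0)
        stages.append(list(range(start, start + size)))
        start += size
    return stages


def pipelined_latency(layer_times: List[float], num_stages: int) -> float:
    """Latency of one decode step: every layer still runs once, in order, so the
    total is the same sum however the layers are split (rack hops ignored)."""
    stages = split_layers(len(layer_times), num_stages)
    return sum(sum(layer_times[i] for i in stage) for stage in stages)


def weight_bytes_per_rack(total_weight_bytes: float, num_stages: int) -> float:
    """Each rack only stores its own stage's weights (assumes equal-sized layers)."""
    return total_weight_bytes / num_stages


layer_times = [1.0] * 60  # hypothetical: 60 layers, 1 ms each
print(pipelined_latency(layer_times, 1), pipelined_latency(layer_times, 4))  # 60.0 60.0
print(weight_bytes_per_rack(1e12, 4))  # 250000000000.0
```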
Makes a ton of sense. So basically, it's a no-brainer to use pipelining during inference, but there's a harder trade-off during training.
So even in inference, in fact, it is not used a ton. It reduces your memory capacity requirements, but there's actually a huge surplus. As you were saying, a rack of Blackwell has many terabytes, maybe tens of terabytes, and that's much bigger than a trillion-parameter model, which only needs one terabyte. So it already fits, and there's not a huge benefit from pipelining, because you're reducing a number that's already pretty small. But it does suggest that, theoretically, you had too much memory there, and maybe you could have built different hardware that has less memory.
If you were designing your hardware and you said, "I actually don't need that much memory, because I don't need the weights to fit in one rack; I can fit the weights in eight racks," then you could have built hardware that didn't have so much HBM per GPU.
Last week, Horace was kind enough to give me and my friends a great lecture on large-scale pre-training systems. There were some concepts I wanted to animate for a write-up on my blog, like how weights shard and gradients flow depending on the parallelism you're using. So I gave Cursor my lecture notes and a sketch I'd made during the lecture, and I asked it to visualize a specific hierarchical collective that Horace had explained.
The first version was already pretty good, and then I was able to use design mode to select and tweak any specific component from there. I was able to do all of this without a clear end state in mind. Cursor's Composer 2 Fast model was quick enough that I could iterate almost instantly. I could try an idea, test the results in the built-in browser, and immediately make any change.
I went through ten different versions in under twenty minutes. If you want to check out this animation, I published it along with the lecture notes in a blog post. The link is in the description. And if you want to try out this kind of iterative design flow for yourself, go to cursor.com slash lowercash to get started.
¶ Why Ilya said, "As we now know, pipelining is not wise."
So, macro question. Everybody's talking about the memory wall right now. Memory's getting super expensive; there's not enough memory. Smartphone volumes will go down thirty percent because there's not enough memory. Hyperscalers are spending... well, this is shocking. I think Dylan said they're spending fifty percent of their capex this year on memory?
That's believable.
What is hyperscaler capex? That's like high hundreds of billions, maybe a trillion, and they're spending half of that on memory. Okay, so that is a huge constraint. That's why we're not going to get new laptops and phones this year. But at the same time, we have too much memory? People are willing to put too much memory into these systems?
Right.
Like, why is Jensen shoving all this memory into these racks if you don't need it?
Yeah. So in the equations we had here before we erased them, we were doing memory time, so memory bandwidth, and compute bandwidth. Let's now start looking at memory capacity. We'll start off with just memory capacity, without even thinking about the parallelism scheme. The demand on memory capacity is the number of total parameters, which is what we need to fit the weights in whatever system we're using, plus the KVs. The KVs go as batch size times the length of the context times bytes per token. Okay, so...
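The capacity formula as a function: total memory = N_total × bytes_per_param + B × L_ctx × KV bytes per token. This is a sketch; the bytes-per-parameter factor is made explicit here (the transcript implicitly counts one byte per parameter), and the example workload (1T parameters, 8,000 sequences of 8,192 tokens, 100 KB of KV per token) is hypothetical. The per-token KV size varies a lot with the attention design.

```python
def total_memory_bytes(n_total_params: float, bytes_per_param: float, batch: int,
                       context_len: int, kv_bytes_per_token: float) -> float:
    """Weights plus KV cache for the whole system, before any parallelism."""
    weights = n_total_params * bytes_per_param
    kv_cache = batch * context_len * kv_bytes_per_token
    return weights + kv_cache


# Hypothetical workload: 1T params at 1 byte, 8,000 sequences of 8,192 tokens,
# 100 KB of KV cache per token.
total = total_memory_bytes(1e12, 1, 8_000, 8_192, 100e3)
print(f"{total / 1e12:.2f} TB")  # 7.55 TB
```

In this example the KV cache (about 6.6 TB) already outweighs the weights (1 TB), which is why batch size and context length dominate the capacity discussion.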
What I was arguing in this context, and the case I was making for pipelining, is that there are some techniques that allow us to solve this. So let's say we're going to run this on some number of GPUs. We're going to have one extent, E, which is the expert parallelism.
That is, when we had this sharding of an expert layer across many GPUs, to what extent do we do that? How many GPUs? We're going to say this is, for example, 64. And then P is going to be the extent of pipelining, so this is a number of racks; who knows, maybe we'll pick four or something like that. So this is the total memory requirement across the system.
But now I'm going to calculate a per-GPU memory requirement. I'll use a lowercase c_mem. Obviously we just take all these numbers and divide by E and P, really easy. So it's N_total plus the batch times the context length, all divided by E times P. Why is it correct to just divide it this way? Well,
we're saying that the parameters are perfectly divided amongst all the GPUs in a rack, and the layers are perfectly divided amongst the different racks, so that works here. And somehow we're going to arrange, and I'll hand-wave exactly how, the same perfect sharding of the contexts across GPUs in a rack, and by layer across racks.
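The per-GPU version, c_mem = (N_total + B × L_ctx × KV bytes per token) / (E × P), under the same perfect-sharding assumption the speaker hand-waves. This sketch reuses a hypothetical workload (1T parameters at 1 byte, 8,000 sequences of 8,192 tokens, 100 KB of KV per token) with E = 64 and P = 4; the function name is hypothetical.

```python
def per_gpu_memory_bytes(n_total_params: float, bytes_per_param: float, batch: int,
                         context_len: int, kv_bytes_per_token: float,
                         expert_parallel: int, pipeline_stages: int) -> float:
    """c_mem: weights plus KV cache, assumed perfectly sharded over E * P GPUs."""
    total = n_total_params * bytes_per_param + batch * context_len * kv_bytes_per_token
    return total / (expert_parallel * pipeline_stages)


# Hypothetical workload, with E = 64 GPUs per rack and P = 4 racks.
c_mem = per_gpu_memory_bytes(1e12, 1, 8_000, 8_192, 100e3, expert_parallel=64, pipeline_stages=4)
print(f"{c_mem / 1e9:.1f} GB")  # 29.5 GB
```

Note that `batch` here is the global batch across the whole pipeline, which is the quantity the discussion turns to next.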
And sorry, four is the number of racks?
Yeah, for example. So this is the place where we actually need to go back and analyze this batch size B. You were making the comment that there's micro-batching versus global batching. So let's come back to this pipelining diagram. We've got one batch going forward here, and then, as I drew it, it just disappeared. That's not really correct. If you think about how decode works, I have a bunch of tokens that I've generated already.
I do one forwards pass and generate a new token. Then I write that to my KV cache, and then I do another forwards pass that generates the next token. So I'm actually going to be running this batch zero in a loop. I go forwards, and once I finish, I can start the next iteration of the loop up here. So we'll just fill this in.
Nice.
Yes. So we've got two and three. So let's split this batch. B will be the global batch size: it's the number of micro-batches times the micro-batch size. So how many micro-batches do we need? The number of micro-batches in this diagram is four: zero, one, two, three. And then the micro-batch size: this is still this two-thousand-ish number. This is the one that is like...
Mm-hmm. This is the two thousand... sorry, no, this is the three hundred times sparsity.
This is how big the train is that leaves every twenty milliseconds.
Right. Yes. This is going to be the twenty-millisecond train. So the global batch size is the number of micro-batches times the local batch size. The local batch size is set by this hardware parameter. The number of micro-batches is as small as possible such that we can wrap around without leaving any idle time. If we had fewer, we would have had idle time when we wrap around.
And you can just visually see that it is equal to the number of pipeline stages. Sort of proof by picture here: it's four this way and four this way as well. You can look and see that it goes along here and then wraps around after the number of pipeline stages.
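The "proof by picture" can be checked with a toy simulation. This is a minimal sketch under an assumed schedule: micro-batch m enters stage 0 at time step m, moves one stage per step, and loops straight back to stage 0 for its next token. It measures how many stage slots are busy, and shows that idle time disappears exactly when the number of micro-batches reaches the number of stages. The global batch is then that count times the micro-batch size.

```python
def stage_utilization(num_stages: int, num_microbatches: int) -> float:
    """Steady-state fraction of (stage, time-step) slots doing useful work.

    Toy decode pipeline (an assumption for illustration): micro-batch m enters
    stage 0 at time step m, advances one stage per step, and after the last
    stage wraps straight back to stage 0 to produce its next token.
    """
    if not 1 <= num_microbatches <= num_stages:
        # More micro-batches than stages would need two of them on one stage at once.
        raise ValueError("need 1 <= num_microbatches <= num_stages")
    busy = 0
    # Observe one full trip around the loop, after every micro-batch has started.
    for t in range(num_microbatches, num_microbatches + num_stages):
        occupied = {(t - m) % num_stages for m in range(num_microbatches)}
        busy += len(occupied)
    return busy / (num_stages * num_stages)


for m in range(1, 5):
    # prints utilization 0.25, 0.50, 0.75, 1.00 for 1..4 micro-batches
    print(f"4 stages, {m} micro-batches -> utilization {stage_utilization(4, m):.2f}")
```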
Is this what is actually done? As in, will a frontier model today actually use pipelining during inference?
For sure, during massive-scale training this is done. It can be done for inference, but I'm actually going to make the case for why it is less attractive. It is useful for weights, but not so useful for KV. The big challenge is, let's fill this in: the number of micro-batches here ends up being equal to the number of pipeline stages. When we go back and substitute all of that into here...
...we get the number of pipeline stages times this little b showing up in here. And then I'm going to split this into two terms.
We get the full division by E times P over here. We still have division by E times P over here, but the P's cancel: this P and this P. So what we find is that if you increase the number of pipeline stages, the memory footprint of the weights keeps going down and down, but the memory footprint of the activations (the KV cache) stays constant.
So it doesn't actually work. Once you do enough pipelining, and it's really not much, even two stages is often enough, this term becomes very small and this other one becomes the dominant term. The KV cache becomes the dominant term.
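A quick numeric sketch of the cancellation, assuming perfect sharding and one micro-batch per pipeline stage (so the global batch is P times the micro-batch size). All sizes in the example are hypothetical, chosen only to show the trend: the per-GPU weight share falls as 1/P while the per-GPU KV share does not move.

```python
def per_gpu_memory(weight_bytes: float, micro_batch: int, ctx_len: int,
                   kv_bytes_per_token: float, expert_parallel: int,
                   pipeline_stages: int) -> tuple[float, float]:
    """(weight bytes, KV-cache bytes) held by each GPU under idealized sharding.

    Assumes weights and KV cache shard perfectly over expert_parallel GPUs per
    rack times pipeline_stages racks, and that keeping every rack busy requires
    one micro-batch per stage, so the global batch is pipeline_stages * micro_batch.
    """
    global_batch = pipeline_stages * micro_batch
    num_gpus = expert_parallel * pipeline_stages
    weights = weight_bytes / num_gpus
    kv = global_batch * ctx_len * kv_bytes_per_token / num_gpus  # P cancels here
    return weights, kv


# Hypothetical sizes: 1 TB of weights, 2,000-sequence micro-batches,
# 100K-token contexts, 2 KB of KV per token, 64 GPUs per rack.
for stages in (1, 2, 4, 8):
    w, kv = per_gpu_memory(1e12, 2_000, 100_000, 2_000, 64, stages)
    print(f"P={stages}: weights {w / 1e9:6.2f} GB/GPU, KV cache {kv / 1e9:6.2f} GB/GPU")
```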
I know this is wrong; I'm just trying to work out why the logic I'm about to give is wrong. If you're pipelining through many different stages, the KV values are not shared between layers. So why would it not help to pipeline across multiple layers? Because then you don't have to store...
Yeah, you only need to store one layer rather than two layers of KVs, right? So it helps from that perspective. You're right. What's competing with that, though, is that you need to keep all of the racks usefully busy at the same time. And so the number of sequences that are in flight simultaneously has gone up.
Yeah, that makes sense.
So those exactly cancel, and you end up not getting any saving per GPU.
Right. This goes back fundamentally to the point that you're not able to amortize KV caches.
Yeah. So first we said you can't amortize KV caches across batch size. And now we're saying you also can't shard them across pipeline stages. It sucks from both of those points of view.
Okay, so then what is done during inference?
So the DeepSeek paper reports what they do, which is that they just do a lot of expert parallelism. In effect, you should increase your expert parallelism up to your scale-up domain size and then do very little pipelining: maybe none at all, maybe two stages, just enough to make weight storage not too big of an issue. Those are the only two parallelisms that really make sense.
In the past there was tensor parallelism, which meant cutting up the work within an expert, but the experts are so small now that it is not a profitable optimization.
So this goes back to the question: does that mean that frontier labs, when they're doing inference, are basically working within a single scale-up domain?
Yes. I mean, it depends on model size. You could have a very large model, one that exceeds the memory of a rack, and there you should be doing a bit of pipelining. Maybe it's extremely sparse, for example; that would be a reason to do it.
So I guess this goes back to the promise at the beginning of the lecture, which was that this will actually tell you about AI progress as well. To the extent that model size scaling has been slow until recently, because... let me make sure I understand the claim. The claim would not be that you couldn't have trained across more racks; you could have.
It's just that it would not have made sense before, because we didn't have the ability to do inference for a bigger model.
Actually, I'd make the claim that pipelining doesn't help with context length, but it totally helps with model size. Because of the ability to do pipelining, the rack at least should not be a constraint on your ability to fit the model parameters. I guess the other consideration you're asking about is why hasn't it scaled up more, and why did bigger scale-up domains help? We talked through one aspect of that, which is...
It's not because of memory capacity. We have a solution to memory capacity, at least with respect to model size. It's just not a solution with respect to KV cache size. The other issue that shows up is latency.
I was just about to ask. Going from rack to rack, what is the latency cost per hop?
This is very much dependent on the hardware. I can't say with a lot of authority. I think it's probably on the order of a few milliseconds, but it could be off by an order of magnitude.
For a realistic number of pipelining stages you might have?
Yeah.
Okay, so that's...
For a small number of pipelining stages, this is not a huge latency.
Wait, I guess that's like 10 milliseconds per token: two times four-ish, or however many stages you said.
Yeah. Ten milliseconds per token is actually a lot, if it goes from twenty to thirty or something like that. Just to chart the path it goes through: you're going from your GPU or TPU or whatever to a network card, which then goes to a top-of-rack switch...
...and then hops over to the other rack and does the same thing in reverse. So you have to sum up the latencies of these different things.
Sorry, is this the same thing as the data center network?
Yeah, yeah. It may in fact go up to a data center switch and back. It depends on your deployment configuration.
And because it's decode and sequential, the latencies stack up across the stages. You can't do them at the same time.
Yeah.
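A back-of-the-envelope sketch of how the hops stack up. The 2.5 ms per hop is a hypothetical value picked to match the "twenty to thirty" figure above, not a measured one, and the model assumes one hop per stage boundary, including the wrap-around back to the first rack.

```python
def decode_token_latency_ms(base_ms: float, hop_ms: float, pipeline_stages: int) -> float:
    """Per-token decode latency once rack-to-rack hops are added.

    base_ms: time for one token's trip through the model ignoring the network
    (e.g. the ~20 ms "train"). hop_ms: cost of one rack-to-rack hop
    (GPU -> NIC -> top-of-rack switch, maybe a data-center switch, and back
    down). Hops are summed, not overlapped, because decode is sequential.
    """
    # With a single stage there is no rack boundary to cross.
    hops = pipeline_stages if pipeline_stages > 1 else 0
    return base_ms + hops * hop_ms


print(decode_token_latency_ms(20.0, 2.5, 4))  # 30.0
print(decode_token_latency_ms(20.0, 2.5, 1))  # 20.0: no pipelining, no hops
```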
Okay. So I guess this brings us back to the question: is the size of the scale-up at all relevant to why AI model sizes have been what they have been over the last few years, whether through training or through inference?
Yeah. So we talked about the latency of this hop. There is also the T_mem latency, the memory-time latency, which is massively improved by larger scale-up domains. I'll recall T_mem down here. T_mem for the weights was equal to the number of total parameters divided by the memory bandwidth.
Which memory bandwidth are we talking about here? Is it just one GPU?
It is in fact the bandwidth of the GPUs that I can use in parallel to load this weight. I can't use different pipeline stages in parallel because they're not running at the same time, but I can use all the GPUs in my scale-up domain in parallel to load the weights. And so...
This is actually extremely effective. Basically, this memory bandwidth term itself ends up equal to the scale-up size...
Times memory bandwidth per GPU. Yeah.
Yeah, times per-GPU bandwidth. And this term doesn't increase a lot; it maybe increases 1.5 or 2x per generation. But this one increased by something like a factor of eight since Hopper.
So the reason the bigger scale-up matters is not the memory capacity of the whole scale-up domain, but really the memory bandwidth.
Yeah. Pipelining totally solves the capacity problem, but scale-up size helps solve the bandwidth problem.
And solving the bandwidth problem helps you do longer context lengths, which is more and more relevant as these models get more agentic.
Yeah. As a first thing, it lets you just run the model at lower latency. If I ran a very big model on a little H100 box, the latency would be really high.
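A minimal sketch of the T_mem term for weights, where the effective bandwidth is the scale-up size times per-GPU bandwidth. The 1 TB of weights and 4 TB/s per GPU are hypothetical round numbers. Holding per-GPU bandwidth fixed isolates the effect of the domain size: an 8x larger domain reads the weights 8x faster.

```python
def weight_fetch_ms(weight_bytes: float, scale_up_size: int, per_gpu_bw: float) -> float:
    """Time to stream every weight once, with all scale-up GPUs reading in parallel.

    Only GPUs in the same scale-up domain help: other pipeline stages hold
    different layers and are not working on this token at the same moment.
    """
    aggregate_bw = scale_up_size * per_gpu_bw  # bytes/s across the domain
    return 1e3 * weight_bytes / aggregate_bw


# Hypothetical: 1 TB of weights, 4 TB/s per GPU; an 8-GPU box vs a 64-GPU domain.
for size in (8, 64):
    print(f"scale-up {size:>2}: {weight_fetch_ms(1e12, size, 4e12):.2f} ms per pass")
```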
¶ Because of RL, models may be 100x over-trained beyond Chinchilla-optimal
There's Chinchilla scaling, which tells you how big a model should be relative to the amount of data you're going to train it on. But now obviously you're not just trying to optimize for the highest-quality model you can get with training compute. You want the best results a user can get with a mixture of training and inference compute. So then there's a question of how much you should overtrain a model such that the compute, amortized over training and inference, is
minimized for a given level of performance. But now with RL there's another consideration, which is that you're going to do some amount of pre-training, and that pre-trained model will be used both for RL generation and then for inference for the final user.
And by overtraining here I mean: while it would have been more efficient, just from a training-compute perspective, to have a bigger model that you train for less time because it learns faster, maybe you get a smaller model and spend more compute training it than you otherwise would have. But now it's cheaper to serve to users.
Okay, but let me make the question more concrete. How much beyond Chinchilla-optimal are models overtrained? And has that changed as a result of RL generation?
This is a place where we have to do a bit of guesswork, because the updated scaling laws and the model traffic numbers are not reported. So we have to guess there. But one way to look at it...
Let me first just make a general heuristic claim. Say I've got a total cost, which is a sum of cost A and cost B. Maybe this is the training cost and this is the inference cost.
Yeah.
And I want to minimize this sum. For many of the curves that come up, the minimum tends to be where the costs are equal. That's something of a heuristic claim, but there are many examples where it's true, like when one is 1/x and the other is x. They tend to be minimized at the point where they equal each other.
It's also true for e^x and e^-x and all kinds of other things. So basically I've got some curve that's going down and some other curve that's going up, and their sum tends to be minimized at the point where they're equal. Heuristically, I will conjecture that that is true for the setup you described as well.
Actually showing that would require looking at the scaling laws and fitting these weird exponents, but things that follow power laws tend to have this property. So I'll just make that claim and move on. We're going to say that for the cost of training plus the cost of inference, we want to equalize the two.
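A brute-force check of the "minimize at equal costs" heuristic. This is a sketch, not a proof. For a sum a·x^(-p) + b·x^q, setting the derivative to zero gives cost A / cost B = q/p at the optimum, so the costs are exactly equal only when the exponents match, and otherwise within a constant factor. The third case below, 1/x + x², lands at A = 2B.

```python
import math


def grid_argmin(f, lo: float, hi: float, n: int = 100_000) -> float:
    """Brute-force minimizer: evaluate f on n+1 evenly spaced points."""
    return min((lo + (hi - lo) * i / n for i in range(n + 1)), key=f)


# (decreasing cost A, increasing cost B, search interval)
cases = {
    "1/x + x": (lambda x: 1 / x, lambda x: x, 0.01, 10.0),
    "e^-x + e^x": (lambda x: math.exp(-x), lambda x: math.exp(x), -5.0, 5.0),
    "1/x + x^2": (lambda x: 1 / x, lambda x: x * x, 0.01, 10.0),
}

for name, (cost_a, cost_b, lo, hi) in cases.items():
    x = grid_argmin(lambda t: cost_a(t) + cost_b(t), lo, hi)
    print(f"{name:>10}: x*={x:.3f}  A={cost_a(x):.3f}  B={cost_b(x):.3f}")
```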
We'll do pre-training only first, because... well, actually, we can do all of it in general. So we'll cost it as follows. The cost of pre-training is the number of active parameters times the data in pre-training, with a factor of six out front, which is the number of FLOPs. That's the famous 6ND formula.
And then in RL we have approximately the same thing: the same number of active parameters, but now the data is the RL data. There's this extra efficiency multiplier, or rather inefficiency.
Which is the fact that you're not training on all your rollouts?
Well, yeah. And then the other, perhaps even bigger, inefficiency is that this involves a substantial amount of decode, and decode often runs at lower MFU than training.
Okay. So if you're doing a backward pass on every single generation in RL, it would be six N D.
Yeah. So this could be a smaller number.
It would be at least two.
Somewhere in the range of two to six. So we'll say somewhere in the range of two to six and leave it at that. And then we can add in the inference cost. The inference cost is two times the number of active parameters times the data in inference.
And sorry, I think the way I said it was super garbled, so just for the audience: forward plus backward is six FLOPs per parameter. Forward alone is two. That's why RL is two to six: you're definitely going to generate all the trajectories, but you might or might not train on all of them.
Yes. Thank you. And then inference is just two.
So we're going to solve for, essentially, equality of all three of these terms; that is the ballpark where people are going to be. Labs have more information on what is productive, for example in doing more RL versus more pre-training. I don't have that information. But I think a good ballpark is a thirty-three percent split between each of them.
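Collecting the three terms, here is a sketch of the cost model being equalized. Two symbols are introduced here: k is the FLOPs per active parameter per RL token (2 for forward-only up to 6 with a full backward pass), and η is the RL efficiency relative to pre-training, guessed at roughly 30% in this conversation:

```latex
C = 6\,N_{\text{act}}\,D_{\text{pre}} + \frac{k}{\eta}\,N_{\text{act}}\,D_{\text{RL}} + 2\,N_{\text{act}}\,D_{\text{inf}}, \qquad 2 \le k \le 6

6\,D_{\text{pre}} = \frac{k}{\eta}\,D_{\text{RL}} = 2\,D_{\text{inf}}
```

The second line is the equal-thirds condition after dividing out N_act, which is why the number of active parameters drops out of the comparison.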
Actually, I'm not sure I understand the intuition for that. Another naive model could have been that RL plus pre-training would be fifty percent and inference would be fifty percent.
Yeah, that's also a valid answer. Because this is heuristic, I can't really argue for one versus the other. They don't differ by that much; thirty-three versus twenty-five is only a small factor.
So let's pick one of them. All equal seems simple enough, so we're just going to solve for equality. It's pretty straightforward. We can immediately see that the number of active parameters totally disappears, so let's factor that out. And we're going to just say that data in pre-training...
I decided to do it your way; it's a little bit nicer, actually. So data in pre-training plus... oh, I didn't have the inefficiency over here either. Data in pre-training plus some multiple alpha times the data in RL is going to end up equal to some beta times the data in inference. And then let's roughly size alpha. This alpha is going to be...
It's maybe somewhere in the range of two to six over six, from this term compared to this term. And then we've got an inefficiency term, which I would say is maybe in the range of thirty percent or something like that. So this alpha is going to be something like one in ten.
One over ten, let's say. And this beta here is actually the same: it's a third, one third times thirty percent, so it also comes to about one in ten. Something like that.
If both of them are one in ten, that kind of implies that there's never a backward pass in RL.
Yeah, okay. We can make this two in ten, make it a bit bigger. So just to write it out once more: this is two over ten, this is one over ten. So the number of inference tokens you have is just a function of...
...hundreds of millions of tokens per second, times my model being deployed for, I don't know, two months before I shift to the next version. That should determine the number of tokens in RL and pre-training. And I guess we didn't do the equivalence between pre-training and RL, so we'll do that here. Data in pre-training should be equal to two over ten times data in RL for them to be cost-equivalent.
Sorry, it's one over that; I got it backwards. We pay more cost when it's inefficient, so this needs to be one over. So, tracing this back through...
Next yeah
This thing ends up actually being, as written here, like 1.5, and this is one.
Yeah.
Billions of dollars of compute just flowed the other direction.
Right, right. I think if you do it in a spreadsheet and actually model it out, you might notice when the money's going down the drain. So all of these end up being close, as modeled here. This 30% may have been a little bit too generous, so let's say something like 1.5 here and leave this as one.
Sorry.
I think at this point you can almost read it off. The number of inference tokens should be about the same as the number of pre-training tokens, which should be about the same as the number of RL tokens, within factors that we're not able to reason about.
But sorry, just doing basic algebra for a second: it seems like there should be fewer RL tokens than pre-training tokens.
Yes, that's right in general. Because RL is less efficient in terms of machine time, if you're trying to equalize RL and pre-training time, then you should have fewer RL tokens in order to have the same wall time.
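A small sketch of the equal-thirds solve, per active parameter. The defaults (4 FLOPs per parameter per RL token, 30% RL efficiency) are illustrative guesses within the ranges discussed, not reported figures. With these inputs, pre-training gets about a third as many tokens as inference, and RL gets fewer tokens than pre-training, matching the point just made.

```python
def equal_cost_tokens(d_inf: float, rl_flops_per_param: float = 4.0,
                      rl_efficiency: float = 0.3) -> tuple[float, float]:
    """Pre-training and RL token counts whose cost matches the inference cost.

    Cost per active parameter (the 6ND / 2ND accounting):
      pre-training: 6 * D_pre
      RL:           (rl_flops_per_param / rl_efficiency) * D_RL, with
                    rl_flops_per_param between 2 (forward only) and 6
      inference:    2 * D_inf
    rl_flops_per_param=4 and rl_efficiency=0.3 are illustrative guesses.
    """
    target = 2.0 * d_inf                      # cost of the inference term
    d_pre = target / 6.0
    d_rl = target * rl_efficiency / rl_flops_per_param
    return d_pre, d_rl


d_pre, d_rl = equal_cost_tokens(d_inf=1.0)
print(f"D_pre = {d_pre:.2f} x D_inf, D_RL = {d_rl:.2f} x D_inf")
```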
This is all quite interesting. I never thought about it in terms of equalizing data.
I mean, I think starting with equalizing cost is right, but depending on how you model the cost, this comes close to equalizing data.
That... basically, for GPT-5 to be trained optimally, the total number of tokens streamed to every single user of GPT-5 should equal the total number that went into pre-training.
Yeah.
And the total number of tokens that went into pre-training is the sum of all human knowledge. So each model should generate, as output, the sum of human knowledge that it got as input.
Yeah. So which way are people going to err? If you think that people's power of prediction is not perfect, and also that you run the risk of making a model that is not a frontier model and then you just throw it away, then that changes the cost trade-off, because there's some probability that applies to the inference. You should derate the inference tokens by some amount. Right.
And then can we back out how much more compute than Chinchilla-optimal a model of a given size gets?
So I think we just have to make some real-world assumptions here in order to do that. The inference tokens we should totally be able to count. Let's say a few hundred million, maybe five hundred million tokens a second now; I don't really know. So 500 million tokens a second, times a model being deployed for two months before it becomes obsolete; I don't really know that either. I can't do this in my head. Can you type it into a computer?
Two point six times ten to the fifteen.
Okay, 2.6 × 10^15. This number is probably too large, because it's going to be split across multiple models in a family. So let's make it five or ten times smaller, something like that. So we're estimating maybe fifty million tokens per second per specific model. The model is live for two months, and so this comes out to around 200 trillion tokens.
And then we want to compare that to active parameters on a frontier model. I don't actually know the latest rumors, but...
Somebody told me a hundred and fifty trillion.
Active params? Oh, trained on 150 trillion tokens, interesting.
Which is similar. Yeah.
Yeah, those are actually similar. So that's the data in pre-training.
This is not well cited, but...
You want me to not remove that?
Okay.
And I think the number of active params could often be in the range of a hundred billion, something like that, maybe a bit larger. So, assuming about a hundred billion active params, multiply by twenty to get the Chinchilla token count. So D_Chinchilla would be around two trillion. And we see that we're at a hundred times larger than that.
Actually, what does D_Chinchilla mean? Oh, I see. So how much is it overtrained?
Got it. So the ratio of this hundred or two hundred trillion tokens to the Chinchilla-optimal two trillion is the amount it's overtrained, which is a factor of about a hundred.
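The arithmetic above, as a sketch. Every input is one of the rough guesses from the conversation: 500M tokens/s across the fleet, a two-month (60-day) deployment, a 10x split across a model family, 100B active parameters at about 20 tokens per parameter for Chinchilla, and the rumored 150T pre-training tokens.

```python
SECONDS_PER_DAY = 86_400


def lifetime_tokens(tokens_per_second: float, days: float) -> float:
    """Tokens served over a deployment window."""
    return tokens_per_second * days * SECONDS_PER_DAY


fleet_tokens = lifetime_tokens(500e6, 60)        # all models, two months
per_model_tokens = fleet_tokens / 10             # one model out of a family
chinchilla_tokens = 20 * 100e9                   # ~20 tokens per active parameter
print(f"fleet: {fleet_tokens:.2e}, one model: {per_model_tokens:.2e}")
print(f"Chinchilla-optimal tokens: {chinchilla_tokens:.1e}")
print(f"overtraining at 150T pre-training tokens: {150e12 / chinchilla_tokens:.0f}x")
print(f"overtraining at per-model inference tokens: {per_model_tokens / chinchilla_tokens:.0f}x")
```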
Okay. So consider this: to the extent this is in the right ballpark, just by thinking, okay, you want everything to be equal in terms of compute, if OpenAI also realizes that and they're serving a certain number of tokens per second...
...that tells you how much data went into the pre-training of GPT-5. Even if it's fifty percent off or something, it is sort of wild that you can derive these kinds of numbers from first principles.
I mean, this is also why you should just approximate everywhere, because there are such big error bars on this. But yeah, it's kind of empowering to just set A equal to B and figure it out.
Yeah, that's super cool. Okay, so in this world of trying to deduce things...
¶ Deducing long context memory costs from API pricing
We can publicly look up the prices of the APIs of these models, and maybe you can learn something from that. So first, with longer context.
Gemini 3.1 is fifty percent more expensive if you go over 200K tokens than if you're below 200K tokens. At a high level I understand why that might be, but why specifically fifty percent?
Yeah. So why specifically fifty percent? At a high level, even in the first place, there is some amount of increase in cost with context length. We can bring that back up. That was the...
...memory time versus the compute time. Okay, so we've put up these same equations from before: the time for memory fetches, which covers the weights and the KV cache, and the time for compute, which is just the matrix multiplications for the weights. I will also draw the cost curve.
But this time I'll do it as a function of context length instead of as a function of batch size. So this axis is just time, and this is the cost curve as a function of context length.
We'll draw the compute. The cost of the compute is actually constant as a function of context length; there's no dependence here on context length. In reality there is some dependence, but it's very mild, so we'll ignore it. So this is the time for the compute.
And then we'll also draw the dependence of the memory fetch on context length. This starts at a large number for the weights and then grows gradually with the context length. So maybe here, and then it grows gradually with context length.
And so you take the maximum, and you see there is this inflection point here. So this is the cost that, for example, Gemini might be paying. And then you think: how might you put a pricing structure on top of that? You would like to ensure that no matter what the context length is, you are still profitable. So...
Interesting.
And so we've got a two-tier pricing structure; maybe we've got something that looks like this up to some context length X.
Fascinating.
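A minimal sketch of the cost curve just drawn, using hypothetical deployment numbers (not Gemini's) with a FLOPs-to-bandwidth ratio of 300. Compute time per step is flat in context length, following the board's one-operation-per-active-parameter count. Memory time is weights plus KV cache, so it rises linearly. The per-token cost is the larger of the two, which gives a flat region and then a knee.

```python
def step_times(ctx_len, batch, active_params, weight_bytes, kv_bytes_per_token,
               flops, mem_bw):
    """(compute seconds, memory seconds) for one decode step of the whole batch.

    Compute counts one operation per active parameter per token, as on the
    board (double it for a 2-FLOPs-per-parameter convention); it does not
    depend on context length. Memory time reads all weights once plus every
    sequence's KV cache, so it grows linearly with context length.
    """
    t_compute = batch * active_params / flops
    t_mem = (weight_bytes + batch * ctx_len * kv_bytes_per_token) / mem_bw
    return t_compute, t_mem


def crossover_ctx_len(batch, active_params, weight_bytes, kv_bytes_per_token,
                      flops, mem_bw):
    """Context length where memory time catches up with compute time."""
    t_compute = batch * active_params / flops
    return (t_compute * mem_bw - weight_bytes) / (batch * kv_bytes_per_token)


# Hypothetical deployment: FLOPs/bandwidth ratio of 300, all sizes illustrative.
hw = dict(flops=9e14, mem_bw=3e12)
model = dict(batch=1024, active_params=1e11, weight_bytes=2e11, kv_bytes_per_token=2_000)
knee = crossover_ctx_len(**model, **hw)
print(f"knee at ~{knee:,.0f} tokens of context")
for ctx in (25_000, 50_000, 100_000, 200_000):
    tc, tm = step_times(ctx, **model, **hw)
    print(f"ctx {ctx:>7,}: cost per token {max(tc, tm) / model['batch'] * 1e3:.3f} ms")
```

Because the curve is flat below the knee, a single lower-tier price covers every context length up to roughly that point. Past it, cost keeps climbing, which is consistent with charging more above a threshold.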
So I think it says something that, given the bump is at 200K, this is probably somewhat aligned with the crossover point, maybe not exactly. We can actually probably complete that calculation just to see where it lands. We can solve for the number of bytes per token if we make some assumptions about the number of active parameters.
So, solving for the number of bytes per token: we're going to assume the point where the memory time equals the compute time is at, let's say, 200K tokens. So we equalize these two. We're also going to assume that the batch size is large enough that the memory time spent on weights is negligible, so we'll forget about that term...
...and focus on the memory time spent on the KV cache. So, copying this term over: batch times length of context times bytes per token, over memory bandwidth, is going to be equal to...
Yeah.
...number of activated params over FLOPs.
And then we're going to solve for bytes per token.
Batch size was missing here. It shows up here and then cancels out by the time we get to here.
And I dropped the length of context.
So we can plug in numbers. This number is the reciprocal of the one we saw before: it's about one over three hundred, which is reasonably stable across many different hardware platforms. We conjecturally said that the number of activated parameters is maybe a hundred billion. And the length of the context we said was 200K. Something is wrong here: the length of the context should be in the denominator, not the numerator.
That's 1,667, so about 1.7 kilobytes, almost two kilobytes.
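The same back-solve as a sketch. Setting B·L·β/BW equal to B·N_active/FLOPs (weight traffic assumed negligible at large batch), the batch cancels, leaving β = N_active·(BW/FLOPs)/L. This follows the board's one-operation-per-parameter compute count. Under a 2-FLOPs-per-parameter convention the compute side doubles, and so does the answer.

```python
def kv_bytes_per_token_at_knee(active_params: float, ctx_len: float,
                               bw_over_flops: float) -> float:
    """Back out KV-cache bytes per token from where the price step sits.

    B * L * beta / BW = B * N_active / FLOPs, so beta = N_active * (BW/FLOPs) / L.
    """
    return active_params * bw_over_flops / ctx_len


beta = kv_bytes_per_token_at_knee(100e9, 200_000, 1 / 300)
print(f"~{beta:,.0f} bytes of KV cache per token")  # ~1,667 bytes of KV cache per token
```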
So that is plausible, actually. You said around two kilobytes. Let's just do a sanity check on what this could be. There are two mechanisms people use to do attention with a small number of bytes per token. One is dense attention with a lot of reuse across layers. Character.AI has a blog post talking about that, alternating long and short context. And in the Character.AI kind of model, which also showed up in the Gemma models,
the global context, which is really what we're talking about here, was shared across all the layers. To get these two kilobytes, you could have, for example, a d_head of 128, which is typical. And the number of bytes is typically the number of attention layers times two times d_head times the number of Q heads.
So this is the number of unique contexts per layer: do you share the context across many layers, or do you use it only once? In Character.AI-like models, this number is one. We said this is a hundred and twenty-eight. And this is a choice which typically ranges from one... sorry, this is KV heads. I meant...
KV heads are the heads that are stored in memory; they store the contents of the previous tokens. Q heads are the retrieval heads. They're only used temporarily, by the attending tokens. So in this autoregressive context, I've got KV heads associated with all of the context, and Q heads associated with this new token here.
But what about the 128?
Oh, this d_head is the dimension of the vector. And the number of KV heads is typically in the range of one to eight. So it is totally plausible to get this by, for example, having eight KV heads and a d_head of 128. That gives you exactly this number. Or you could have fewer KV heads but a larger d_head.
So this is one way to get there, via dense attention. There's also a way to get there via sparse attention, where you increase all of these numbers but then have a one-over-sparsity term. So I think this number is plausible, if maybe a little bit small.
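The dense-attention formula as a sketch. The spoken formula leaves bytes per element implicit. The "exactly two kilobytes" figure for 8 KV heads of dimension 128 with a single shared KV layer works out at one byte per element, for example an 8-bit KV cache; that is an assumption here, and a bf16 cache would double it.

```python
def kv_bytes_per_token(unique_kv_layers: int, d_head: int, n_kv_heads: int,
                       bytes_per_element: int = 1) -> int:
    """KV-cache bytes stored per token of context.

    unique_kv_layers: layers whose KV cache is not shared with another layer
      (1 in a Character.AI-style design where global layers share one cache).
    The factor 2 is one key and one value vector per KV head.
    bytes_per_element=1 assumes an 8-bit KV cache; use 2 for bf16.
    """
    return unique_kv_layers * 2 * d_head * n_kv_heads * bytes_per_element


print(kv_bytes_per_token(unique_kv_layers=1, d_head=128, n_kv_heads=8))  # 2048
print(kv_bytes_per_token(1, 128, 8, bytes_per_element=2))                # 4096
```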
It's funny that they would leak so much information through their API prices.
I mean, you are incentivized to price close to your costs, because otherwise someone could undercut you.
Maybe we can learn something from the difference between input and output prices, and what that tells us about decode versus prefill in these models. Last time I checked, it was fifty percent more expensive or something like that.
I don't remember. What I've seen in the past is three or five times.
That makes more sense. Let's say it's five times more expensive. Okay. This is the compute to process the next token in decode. Suppose you're doing prefill, where you're not just processing the most recent token, you're processing all the tokens in parallel. So I want to say that it would be this times the length...
Prefill. Yeah. We can think of decode as being a pass with one token, and prefill as a pass with many.
Okay. So maybe call it the prefix? Okay, memory. So you're not storing the KV cache for the tokens that are the prefill tokens?
Maybe let's actually draw how prefill shows up here, if I may clarify. So we do a bit of decode like this. We may actually come back and do more prefill. If you think of this as a chat session, the user says something, the AI generates a response, and then the user says something else, and we prefill this. So this is actually the more common case; it's the general case rather than this one.
And in fact, this is like you read a file or something.
Or just the AI responding to user input, or a tool call, or anything that's not...
Yep. Okay, so suppose we're here. You will have calculated all of this previously, so you need to load just the KV of everything that came before. But what is the memory cost, the memory bandwidth cost, of this? If you're doing flash attention...
Yeah, it's basically temporary. It doesn't even go to main memory. Just ignore it.
Okay. So then it would just be everything that came before. So is it not just that then?
Yeah, there's actually no adjustment at all to the memory time.
Great. Oh, it's a very trivial change to accommodate. So this term is making it five times more expensive. Now, why would that be? What are we trying to learn here? What does that actually tell us? What variable does it help us clamp? Well, the only thing that could have changed is the compute, which presumably got five times more expensive as a result.
So yeah, this is the time for one pass, but the number of tokens is that much larger too. So I guess we want the cost per token, or the time per token.
Sorry, I'm not sure I understood. This is for processing the next token in prefill?
Well, actually for processing the entire batch. So at this cost, we have processed this many tokens: the length of the prefill.
Yeah.
I guess, yeah, the length of the pass. Yeah, not this prefix, but this cost. Okay.
So this is five times more expensive. But... oh, but why is it five times more expensive?
So the result we want to work towards is that prefill is compute-limited and decode is memory-bandwidth-limited.
Why don't we do this? Why don't we just start with len_pass on the x-axis? Yep, yeah. And T on the y-axis.
For T, we want the cost per token, so it'll be T over something: T over the length of the pass. Yeah, that'll be right.
Okay.
So...
Okay.
Here's what confuses me about this. It seems like len_pass should be higher when you're doing prefill.
Prefill has a bigger len_pass. Yeah. Right.
But then why is its cost higher?
Yeah, yeah. So it's this division by len_pass that actually makes it all work.
Okay.
This is going to divide out, this is going to divide out, but then all of this is going to be divided by the length of the pass, and that's going to make the memory cost cheaper.
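The division being described can be written out as a rough sketch. The symbols are our own assumptions for illustration, not notation from the lecture: N active parameters at about 2 FLOPs per parameter per token (attention FLOPs ignored), F the chip's FLOP/s, B_mem its memory bandwidth, W the bytes of weights and K the bytes of existing KV cache, each read once per pass, and ℓ the length of the pass.

```latex
\begin{aligned}
T_{\mathrm{compute}} &\approx \frac{2N\ell}{F},
& T_{\mathrm{mem}} &\approx \frac{W + K}{B_{\mathrm{mem}}}, \\
\frac{T_{\mathrm{compute}}}{\ell} &\approx \frac{2N}{F} \ \text{(flat in } \ell\text{)},
& \frac{T_{\mathrm{mem}}}{\ell} &\approx \frac{W + K}{B_{\mathrm{mem}}\,\ell} \ \text{(falls as } 1/\ell\text{)}, \\
\ell^{*} &= \frac{F\,(W + K)}{2N\,B_{\mathrm{mem}}} &&\text{(compute-bound for } \ell > \ell^{*}\text{)}
\end{aligned}
```

Decode sits at ℓ = 1, far to the left of ℓ*; a long prefill sits to the right of it.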
Okay, let me think about this then. So basically, we'll have four different lines. Let's do prefill first. Actually, let's do decode first.
So actually, the length of the pass: when it's one, that is decode. When it is bigger, that is prefill.
Okay, yeah. That makes sense. Getting back to it. So T_compute divided by len_pass is just this amount. It doesn't vary with len_pass, so it'll just be some flat value, like this. And this is
Yeah.
And then this is...
That's decode.
Decode, right. Now T_mem: if we divide this whole thing by len_pass, it doesn't really matter what's up there. It'll just be something that falls off like one over len_pass, like this.
Right. Yeah.
Say this is T...
T_mem.
This is decode again. So as the length of the pass goes up, your memory bandwidth time per token declines, which means that to the extent you were bottlenecked on memory bandwidth before, you can avoid that. The fact that they charge five times less for prefill than for decode does suggest that they are bottlenecked on memory bandwidth to quite a degree, at least for them, because T is equivalent to cost, right? It's the cost of renting the compute. So this would be at one and this would be at five. Yeah. So it is in fact tremendously memory-bandwidth bottlenecked. Like that. Exactly. So yeah, let me do it this way. Yeah, that's right.
And then these
This is the gap on decode between the memory and the compute time.
Yeah.
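A minimal Python sketch of the curves on the board, under stated assumptions: one unbatched forward pass over len_pass tokens, about 2 FLOPs per active parameter per token, attention FLOPs ignored, all weights plus the existing KV cache read once per pass, and compute overlapping with memory traffic. The model and chip figures (`MODEL`, `CHIP`) are hypothetical round numbers, not figures from the lecture.

```python
# Per-token time of one forward pass as a function of len_pass.
# All numbers below are hypothetical round values, not measurements.

def per_token_times(len_pass, active_params, weight_bytes, kv_bytes, flops, mem_bw):
    """Return (compute, memory, total) seconds per token for one pass.

    Assumes ~2 FLOPs per active parameter per token (attention FLOPs ignored)
    and that the pass reads all weights plus the existing KV cache once.
    """
    t_compute = 2 * active_params * len_pass / flops
    t_mem = (weight_bytes + kv_bytes) / mem_bw
    # If compute and memory traffic overlap, the pass takes the slower of the two.
    t_total = max(t_compute, t_mem)
    return t_compute / len_pass, t_mem / len_pass, t_total / len_pass


def crossover_len(active_params, weight_bytes, kv_bytes, flops, mem_bw):
    """len_pass at which per-token compute time equals per-token memory time."""
    return flops * (weight_bytes + kv_bytes) / (2 * active_params * mem_bw)


MODEL = dict(active_params=70e9, weight_bytes=140e9, kv_bytes=10e9)  # hypothetical
CHIP = dict(flops=1e15, mem_bw=3e12)                                 # hypothetical

for len_pass in (1, 10, 100, 1_000, 10_000):
    c, m, t = per_token_times(len_pass, **MODEL, **CHIP)
    bound = "memory" if m > c else "compute"
    print(f"len_pass={len_pass:>6}: {t * 1e6:9.1f} us/token ({bound}-bound)")

print(f"crossover ~ {crossover_len(**MODEL, **CHIP):.0f} tokens")
```

In this unbatched toy, decode (len_pass = 1) is memory-bound by a factor of several hundred. Real serving batches many sequences into each decode step to share the weight reads, which is presumably why the observed price gap is closer to 5x.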
Yeah. Okay, interesting. Another interesting one would be why cache hits are so much cheaper.
Yeah, okay.
If I remember correctly, according to the pricing on all these models, it's more expensive to write to the cache, but if you do hit the cache, it's 10x cheaper. So what is going on? Presumably this is the cost of keeping something in HBM rather than just evicting it.
Right. So there are two ways you can produce the KV cache for a token. You can produce it from scratch by computing it from the underlying token IDs, which are tiny. Or you can have produced it previously and stored it in memory somewhere. So the cost ratio is really the ratio between those two mechanisms of producing it. A cache miss means you've deleted it from all your memories and you have to recompute it from the tokens directly.
In fact, you can take that a step further and think about which memory tier you store it in. You could store it in HBM. There are other memories that are slower and cheaper than HBM, like DDR on your host, or flash. So one thing you can do is a calculation of when it makes sense to be in each memory tier, and this is related to how long you're going to store it.
So we want to look at the cost of storage in a few different memory tiers, and also the cost of rematerialization. Remat means rebuilding all of the KV cache from scratch after you've deleted it. Basically this is going to scale with the length of the context. Actually, we'll look at cost per token so that we don't need to carry around the context length. To rematerialize one token of KV cache,
I need to run a forward pass on the whole model. So this is going to be the compute time: I have to rerun the compute at whatever speed my GPU does it, and then I multiply that by my GPU dollars per second.
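A rough per-token version of the rematerialization cost, as a sketch under our own assumptions: about 2N FLOPs per token for a model with N active parameters, a recompute that runs compute-bound (like prefill) at an effective rate F_eff FLOP/s, and a rental price of p_GPU dollars per GPU-second. The attention contribution is left out here.

```latex
\mathrm{cost}_{\mathrm{remat}} \approx \frac{2N}{F_{\mathrm{eff}}} \cdot p_{\mathrm{GPU}}
\qquad \text{(GPU-seconds per token} \times \text{dollars per GPU-second)}
```

This is the per-token number that any storage option has to beat.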
Sorry, extremely naive question: why is there no quadratic term?
Yeah, so there is a quadratic term. It shows up in the compute. As an approximation, I chose to remove it. I'll just quickly show you what that looks like. If you look at the cost per token, or the number of FLOPs per token, as a function of context length, there are the FLOPs that come from doing the weight matrix multiplies.
And then there are the multiplies that come from attending over the KV cache, which go up linearly with the amount of stuff you attend to. The slope on this is so low that when you draw it like this, it's very well approximated by a flat line. You only start to notice the effect of the quadratic term (linear, per token) up in the millions of tokens or so. So it's just not super relevant.
So what is the reason that no company offers a context length of over a million tokens?
Yeah. So there are two costs of long context. One is the memory bandwidth cost, which we've spent a lot of time analyzing; that's this thing. The other one is the compute cost. The compute cost almost always has a much smaller slope than the memory bandwidth cost, and that's actually forced by
fundamental principles. So the primary things that limit you from having really large contexts are memory bandwidth and memory capacity, which is exactly this effect.
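To make the "much smaller slope" claim concrete, here is a hedged sketch of how much extra decode time one more cached token costs, comparing memory traffic with attention FLOPs. It assumes standard multi-head or grouped-query attention (no KV compression), about 2 FLOPs per multiply-add, and a hypothetical model shape and hardware (`SHAPE`, `HW`) that do not come from the conversation.

```python
# Slope of decode time vs. context length for one sequence: memory vs. compute.
# The model shape and hardware numbers are hypothetical.

def attention_slopes(n_layers, n_q_heads, n_kv_heads, d_head, bytes_per_elem, flops, mem_bw):
    """Extra seconds per decode step for each additional cached token.

    Memory: that token's K and V are re-read in every layer.
    Compute: each query head does one dot product against K and one
    weighted-sum pass over V, ~2 FLOPs per multiply-add each.
    """
    kv_bytes = 2 * n_layers * n_kv_heads * d_head * bytes_per_elem
    attn_flops = 4 * n_layers * n_q_heads * d_head
    return kv_bytes / mem_bw, attn_flops / flops


SHAPE = dict(n_layers=80, n_q_heads=64, d_head=128, bytes_per_elem=2)  # hypothetical
HW = dict(flops=1e15, mem_bw=3e12)                                    # hypothetical

for n_kv in (64, 8):  # 64 = plain multi-head attention, 8 = grouped-query attention
    mem_s, comp_s = attention_slopes(n_kv_heads=n_kv, **SHAPE, **HW)
    print(f"kv_heads={n_kv}: memory slope is {mem_s / comp_s:.0f}x the compute slope")
```

In this setup the ratio simplifies to (n_kv_heads / n_q_heads) × bytes_per_elem × flops / (2 × mem_bw): each cached byte feeds only a few FLOPs per decode step, while the assumed chip offers a few hundred FLOPs per byte of bandwidth.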
And so there's this idea that Dario said on the podcast, and others have said, which is that we don't need continual learning for AGI; in-context learning is enough. And if you believe that, then you have to think that we'd need to get to something like a hundred-million-token context length
to have an employee that is the equivalent of someone who's been working with you for a month. Now, maybe that's no longer true with sparse attention or something. Yeah. But if you think that, then some ML infra thing would have to change, like the memory bandwidth, to allow for hundred-million-token context lengths.
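A back-of-the-envelope sketch of what a hundred-million-token context would mean for KV cache size, assuming a hypothetical grouped-query-attention model (80 layers, 8 KV heads, head dimension 128, 2-byte values) and hypothetical per-GPU HBM figures. None of these numbers come from the conversation.

```python
# Back-of-the-envelope KV cache size for a very long context, one sequence.
# Model shape and hardware figures are hypothetical.

def kv_cache_bytes(context_tokens, n_layers, n_kv_heads, d_head, bytes_per_elem):
    # Factor 2 = one K and one V vector per layer, per KV head, per token.
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_elem * context_tokens

SHAPE = dict(n_layers=80, n_kv_heads=8, d_head=128, bytes_per_elem=2)  # hypothetical GQA model
HBM_BYTES_PER_GPU = 80e9   # assumed 80 GB of HBM per GPU
HBM_BW_PER_GPU = 3e12      # assumed 3 TB/s of HBM bandwidth per GPU

per_token = kv_cache_bytes(1, **SHAPE)
total = kv_cache_bytes(100_000_000, **SHAPE)

print(f"{per_token:,} bytes per token")
print(f"{total / 1e12:.1f} TB for 100M tokens")
print(f"~{total / HBM_BYTES_PER_GPU:.0f} GPUs' worth of HBM just to hold it")
print(f"~{total / HBM_BW_PER_GPU:.1f} GPU-seconds of HBM bandwidth per generated token")
```

Under these assumptions that is 327,680 bytes per token and roughly 33 TB for one 100M-token context, which would occupy about 410 GPUs' worth of HBM and cost about 11 GPU-seconds of HBM bandwidth for every generated token, before any sparsity.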
I mean, sparse attention gives you a way out for sure, because you get this square root; it gives you a big improvement. But look at the history of context lengths of models: from earlier models like GPT-3, maybe to GPT-4 (I don't remember exactly when the transition happened), they shot up from about 8K to 100K, 200K.
And then for the last year or two, they've all been hovering around there. I think that actually indicates that that's roughly the balanced cost point, and going massively beyond it would be cost-prohibitive.
Not because of the compute cost?
Because of the memory bandwidth cost. Yeah. I actually don't see a very good path to solving that. HBM is where it is; it's not getting hugely better.
Why doesn't sparse attention solve it?
Sparse attention is a big improvement, and maybe that is already priced in. It's not an infinite improvement, because if you go too sparse you lose too much quality. But the empirical result is that context lengths haven't been increasing that much, and I think it's because there is no solution to the memory wall. Going too sparse just means you're attending to a very small subset of the tokens, and the quality will get worse.
So what is the cost of these different ways of producing, or resynthesizing, the KV cache? Computing it from scratch is based on my GPU time: I have to do a certain number of multiplies, which is some amount of GPU time I spend to produce it. Storing it, on the other hand, goes as... I think I had a number here, which was the bytes per token.
So I have some number of bytes per token, and I need to store this in HBM, so it's going to use up some of my HBM capacity. One way to think of this: if I fill up my HBM with KV caches that I'm not using, I can't use that GPU. So how do I price that? Maybe I say the cost is proportional to the fraction of the HBM I'm using, so it's that fraction times GPU dollars.
And then let's do one more memory tier, say storing it in DDR instead. The same kind of thing goes for flash and for DDR. I actually put these in the wrong columns. I meant to make two columns. The distinction I want to make is that there is the cost to retrieve.
And then there's the cost to hold on to it. This one is a cost per second, whereas this one is an instantaneous cost. So rematerialization has a cost to retrieve and zero cost to store, because we've deleted it. This is the one I put in the wrong location; this is actually the cost to hold on, so I will rewrite it.
Okay. So if we're just storing it in HBM, it has this sort of cost profile. And if we store it in DDR, it's actually going to take some time to bring back. So we get the same kind of thing here, but the retrieve cost is bytes per token over the DDR bandwidth, and holding it consumes some amount of the DDR.
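The two-column table can be turned into a toy model: each option has a one-off cost to retrieve and a per-second cost to hold, and the cheapest option depends on how long you hold. Every constant below is a hypothetical placeholder (including the choice to price DDR transfer time at the GPU's rate), so only the shape of the result is meaningful.

```python
# Toy cost model for where to keep one token's KV cache, following the
# "cost to retrieve" vs. "cost to hold" columns. Every number is hypothetical.

GPU_DOLLARS_PER_S = 1e-3        # assumed $3.60 per GPU-hour
BYTES_PER_TOKEN = 327_680       # assumed KV bytes per token
REMAT_GPU_S_PER_TOKEN = 1.4e-4  # assumed GPU-seconds to recompute one token
HBM_BYTES = 80e9                # assumed HBM capacity per GPU
DDR_BYTES = 1e12                # assumed host DDR capacity per GPU
DDR_BW = 50e9                   # assumed DDR-to-HBM bandwidth, bytes/s
DDR_DOLLARS_PER_S = 1e-4        # assumed cost of that DDR per second

# tier -> (cost to retrieve in $, cost to hold in $ per second)
TIERS = {
    # Recompute from token IDs: pay GPU time, hold nothing.
    "remat": (REMAT_GPU_S_PER_TOKEN * GPU_DOLLARS_PER_S, 0.0),
    # Already in HBM: free to use, but occupies a fraction of the GPU.
    "HBM": (0.0, BYTES_PER_TOKEN / HBM_BYTES * GPU_DOLLARS_PER_S),
    # In DDR: transfer time (priced at the GPU rate, an assumption) plus a share of DDR.
    "DDR": (BYTES_PER_TOKEN / DDR_BW * GPU_DOLLARS_PER_S,
            BYTES_PER_TOKEN / DDR_BYTES * DDR_DOLLARS_PER_S),
}

def total_cost(tier, hold_s):
    retrieve, hold_per_s = TIERS[tier]
    return retrieve + hold_per_s * hold_s

def cheapest_tier(hold_s):
    return min(TIERS, key=lambda tier: total_cost(tier, hold_s))

for hold_s in (1, 300, 3_600, 100_000):
    print(f"hold {hold_s:>7} s -> {cheapest_tier(hold_s)}")
```

With these placeholders, short holds favour HBM, holds from a couple of seconds to about an hour favour DDR, and very long holds favour throwing the cache away and rematerializing.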
And does every scale-up have DDR and...?
That's really a deployment question, so you can choose that. NVIDIA does deploy in this form; it has both.
Why isn't the cost to retrieve from HBM the bytes divided by the memory bandwidth?
Yeah, it depends what you define retrieve to be. Here I'm defining retrieve as moving it into HBM so that you can actually start doing inference on it. So it's zero, sort of by definition.
Because if it's already in HBM, you can be doing compute while you're streaming it out of HBM.
Yeah, for example. So there are three things here, and I guess I ordered them wrong. In general, if you're balancing two costs across different tiers of the memory hierarchy, you should expect that as this cost goes up, this cost goes down. You can kind of see where the zeros are. I should have ordered them with this one first,
this one second, and this one third. And all of this holding cost is multiplied by the hold time, so the right choice depends on how long you're going to hold on to it.
Yep. And interestingly, they charge different prices for cache writes depending on whether you specify five minutes or an hour in the API.
Yeah.
Which suggests that the five minutes is HBM and the hour is DDR.
I think that's a pretty good assumption. If you look at the numbers, it might also turn out that it's one tier down, and it's DDR versus flash.
Okay, interesting. And the price difference, I think, was... I'll look it up. Okay, so base input tokens are $5 per million tokens.
Which means remat.
Yeah. That's five.
This is five?
To "retrieve," quote unquote. And then to write, presumably to HBM, for five minutes is $6.25.
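Some quick arithmetic on the prices quoted in this exchange: $5 per million base input tokens, $6.25 per million for a five-minute cache write, and the earlier "10x cheaper" for a hit. Treating the write premium as the cost of holding the cache for five minutes is our interpretation, not something the pricing states; the one-hour price isn't given here, so it isn't computed.

```python
# Arithmetic on the prices quoted above, in $ per million input tokens.
BASE_INPUT = 5.00             # uncached input: effectively "remat"
WRITE_5MIN = 6.25             # cache write with a five-minute lifetime
CACHE_HIT = BASE_INPUT / 10   # "10x cheaper" on a hit, per the earlier remark

write_premium = WRITE_5MIN / BASE_INPUT
# Assumption: a write pays for the compute plus holding the cache for five minutes.
hold_5min = WRITE_5MIN - BASE_INPUT

def cost_with_cache(reuses):
    """Write once, then pay the hit price on each reuse."""
    return WRITE_5MIN + reuses * CACHE_HIT

def cost_without_cache(reuses):
    """Pay the full input price on the first use and on every reuse."""
    return (1 + reuses) * BASE_INPUT

print(f"write premium: {write_premium:.2f}x, implied 5-minute hold: ${hold_5min:.2f}/M")
print(f"one reuse: ${cost_with_cache(1):.2f} cached vs ${cost_without_cache(1):.2f} uncached")
```

Under these numbers a single reuse already pays for the write: $6.75 with caching versus $10.00 without.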
So we might actually be able to determine which memory tier it is from the durations. The duration probably tells us.
The five minutes versus one hour.
Yeah, exactly. I think this is going to be the drain time of the memory tier you're in. What that means is, given that I know I'm going to be holding something for five minutes, I would like to
pick a memory whose entire contents I can read about once per five minutes, ballpark. That's the drain time of the memory: storage capacity over storage bandwidth.
Bandwidth?
I would like this to be equal to five minutes or something like that. We actually did this calculation for HBM: for HBM, this number is 20 milliseconds. So HBM is much too short. DDR could be an order of magnitude or two off from that, so it's probably on the order of... actually, I think it might even be in the seconds, like one to ten seconds.
I don't have these numbers memorized, but generally the drain time grows as you go to slower tiers. Flash is plausibly on the order of one minute, and spinning disk, which is massively different, is I think on the order of one hour. So this might actually identify the tiers as flash and spinning disk.
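A small sketch of the drain-time calculation (capacity divided by bandwidth), plus a helper that picks the tier whose drain time is closest to a target hold time on a log scale. The HBM figures are assumed values that land near the ~20 ms mentioned above; the DDR figures are assumed round numbers; flash and disk are left out because we haven't assumed numbers for them.

```python
# Drain time = capacity / bandwidth: how long it takes to read a tier end to end.
# Capacities and bandwidths below are round, assumed figures; plug in real ones.
import math

def drain_time_s(capacity_bytes, bandwidth_bytes_per_s):
    return capacity_bytes / bandwidth_bytes_per_s

TIERS = {
    "HBM": drain_time_s(80e9, 3.35e12),  # assumed 80 GB at 3.35 TB/s
    "DDR": drain_time_s(2e12, 4e11),     # assumed 2 TB at 400 GB/s
}

def best_match(hold_s, tiers=TIERS):
    """Tier whose drain time is closest to the hold time, on a log scale."""
    return min(tiers, key=lambda name: abs(math.log(tiers[name] / hold_s)))

for name, t in TIERS.items():
    print(f"{name}: drain time ~ {t:.3g} s")
print("closest to a 5-minute hold:", best_match(300))
```

With only these two tiers, a five-minute hold is closest to DDR but still a factor of about 60 away, which is the kind of gap that points toward a slower tier.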
Sorry, why is this the calculation, storage capacity divided by bandwidth?
So you've got a bunch of different memory tiers; we've listed four of them. In choosing a memory tier, you want to minimize the cost. So what fraction of the device are you using? You're using some fraction of the device for holding on to it, and some fraction of the device to retrieve it.
Let's say I'm using 10% of the device. I want to equalize those two fractions; that's a sign that I've hit the right balance. So let's say I've got some runtime here, like
I'm going to hold on for all of this time; this is the hold time. And then there's going to be some amount of time here, which is the retrieve time. To equalize these two costs, I want the retrieval time to equal the hold time times the fraction of capacity I'm using, because this is the retrieval time, and this is how many other things I can hold simultaneously.
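The balancing argument fits in a few lines of algebra. Assumed notation (ours, not the lecture's): an item of b bytes, a tier with capacity C bytes and bandwidth W bytes/s, and a per-second cost p for the tier, charged to the item in proportion to the share of the tier it uses. Pricing bandwidth and capacity at the same rate p is a simplifying assumption.

```latex
\begin{aligned}
\mathrm{cost}_{\mathrm{hold}} &= \frac{b}{C}\, t_{\mathrm{hold}}\, p, \qquad
\mathrm{cost}_{\mathrm{retrieve}} = t_{\mathrm{retrieve}}\, p = \frac{b}{W}\, p, \\
\mathrm{cost}_{\mathrm{hold}} = \mathrm{cost}_{\mathrm{retrieve}}
&\;\Longrightarrow\; t_{\mathrm{retrieve}} = t_{\mathrm{hold}} \cdot \frac{b}{C}
\;\Longrightarrow\; t_{\mathrm{hold}} = \frac{C}{W} = t_{\mathrm{drain}}.
\end{aligned}
```

Here b/C is the "fraction of capacity" and b/W is the retrieval time. The item size b cancels, which is why the answer depends only on the tier.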
Basically: you want to store things in there for long enough that the time they're in there is about the time it takes to get all your things in and out. Yeah, basically.
I think that probably indicates that these are the two tiers of flash and spinning disk. I'm kind of shocked to see spinning disk being used at all, because it's such an old technology, but yeah.
I mean, it's also crazy that it's so slow that it takes an hour to fill it up to its full capacity.
Like it's a really unattractive technology, but it's useful in some places.
So we're sitting down because I want to ask you some questions that I guess don't need the blackboard. You have this extremely interesting blog post where you talk about
¶ Convergent evolution between neural nets and cryptography
how, at a high level, the architecture of different cryptographic protocols looks a lot like neural networks. There's this convergent evolution where they both need to jumble information across all their inputs. For cryptographic protocols, it's to make sure that
each new input into a hash function will totally scramble what happens. Neural networks, of course, need to consider how this piece of information changes what you should make of that other piece of information. I thought that was an extremely interesting point. I guess the difference, at a high level, is in what they're trying to do. In some sense they're trying to do the inverse thing. Right.
Cryptographic protocols are trying to take information which has structure and make it indistinguishable from randomness. Yeah. And neural networks are trying to take things which look random, like protein sequences, DNA, or garbled text, and extract higher-level structure from them. So they have similar high-level mechanisms, but they're actually trying to do opposite things. What do you make of that?
Yeah. So, the mixing: I tried to look for other examples where mixing, or scrambling, shows up as well. There's almost even a physical example, where
you're stirring something: you're making a cake and you want to stir the batter. Literally the idea of first stirring it this way and then stirring it that way is actually not too bad an approach. But back in the digital world, there are some differences, and the one you call out is a pretty strong difference. Here's the way it shows up.
Take neural nets: if you just randomly initialize a neural network, maybe it's actually a reasonable cipher as well, because the random initialization is going to jumble stuff in a complicated way. It may even do what you want, who knows? The thing that makes it interpretable is gradient descent. You can differentiate a neural network and get a meaningful derivative, and we do a lot of work to
not overcomplicate the derivative. The residual connection keeps it contained and simple, and so does the layer norm stuff that we do. One of the biggest attacks against cryptographic ciphers is also to differentiate the cipher. Ciphers run in a different number field: they run in
the field of two elements, so just binary, whereas neural nets run, in theory, in the field of real numbers. So you have to differentiate with respect to binary numbers. But
you can absolutely differentiate a cipher, and this is called differential cryptanalysis. Basically, you take a small difference in the input and look at the difference in the output. The whole job of a well-designed cipher is to make that output difference very large, so it's quite difficult to get a small difference in the output.
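A small illustration of that contrast: flip one input bit and measure how far the output moves. SHA-256 stands in for "a well-designed cipher" here (it is a hash, not a cipher), and the tanh map is a hypothetical stand-in for a smooth neural-net layer. This demonstrates the avalanche property, not an actual differential attack.

```python
# Avalanche vs. smoothness: flip one input bit and measure the output change.
import hashlib
import math

def bit_diff(a: bytes, b: bytes) -> int:
    """Number of differing bits between two equal-length byte strings."""
    return sum(bin(x ^ y).count("1") for x, y in zip(a, b))

def smooth(data: bytes) -> float:
    # Small weights keep tanh in its roughly linear range.
    return math.tanh(sum(1e-4 * byte for byte in data))

msg = b"stir the batter this way, then that way"
flipped = bytes([msg[0] ^ 0x01]) + msg[1:]  # flip the lowest bit of the first byte

hash_bits_changed = bit_diff(hashlib.sha256(msg).digest(), hashlib.sha256(flipped).digest())
smooth_change = abs(smooth(msg) - smooth(flipped))

print(f"SHA-256: {hash_bits_changed} of 256 output bits changed")
print(f"smooth map: output moved by {smooth_change:.2e}")
```

Real differential cryptanalysis goes further: it tracks how specific input differences propagate through each round and looks for ones that survive with unusually high probability.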
I guess the distinction is that the optimization goals at that point are about complexifying. They don't have the same residual connections or layer norms that would keep the derivative simple.
I guess a place where the two merge is backdoors. Okay, so with a backdoored LLM, you're trying to hide... well, what do you consider an input? It's not an input into the forward pass so much as an input into the backward pass that you're trying to hide.
This is an adversarial context. Yeah. So in fact this is a place where you get exactly the sort of avalanche property that ciphers have as well. Take adversarial attacks on, typically, image classification models: can I find a very, very small perturbation of the image that totally changes the classification, totally changes the output?
That is the common case in ciphers, whereas it's the undesired case in neural nets, for sure.
Okay, so I was asking you whether neural networks have actually been used for cryptography, and we realized it might be better to just do this on the blackboard. Yeah. So I'm curious: are they actually being used for cryptography?
Yeah. So, using neural nets for cryptography: in general, creating a new cipher is a very, very dangerous proposition. Almost all of them are broken, like 99% of them. So that's probably a bad place to start. But the other direction has been, in at least one very clear case, quite productive.
So there's a construction that exists in ciphers and was then imported into neural nets, called a Feistel cipher, or Feistel network. The idea is that you have some function f. It may not be invertible, but you like the function because it does interesting things: it's an MLP, for example, or it mixes its input in an interesting way. You'd like to build something out of this that is invertible.
So the construction we're going to make is actually a two-input function rather than a one-input function. We need to remember what x was, so we stick x over here so that we can work backwards. And we also can't drop y, so we remember y and add it to f(x). So we form this tuple: (x, y + f(x)).
So the way to invert this: if I have this output and I want to recover x and y, I can easily recover x; it's right there, I just read it off. Then, if this second element is called z, I can recover y as z minus f(x), because I've already recovered x. So this construction is invertible.
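A minimal sketch of the Feistel step just described, over 32-bit words so the inverse is exact. The round function `f` is a hypothetical example, deliberately non-invertible, to show that the construction doesn't need f itself to be invertible.

```python
# One Feistel step over 32-bit words: (x, y) -> (x, y + f(x)).
# f is a hypothetical round function; it does not need to be invertible.
MASK = 2**32 - 1

def f(x: int) -> int:
    # Squaring mod 2^32 is not invertible: f(1) == f(2**32 - 1).
    return (x * x + 0x9E3779B9) & MASK

def feistel_forward(x: int, y: int) -> tuple[int, int]:
    return x, (y + f(x)) & MASK  # keep x so we can work backwards

def feistel_inverse(x: int, z: int) -> tuple[int, int]:
    return x, (z - f(x)) & MASK  # read x off, then y = z - f(x)

x, y = 0xDEADBEEF, 0x12345678
out = feistel_forward(x, y)
print([hex(v) for v in out], feistel_inverse(*out) == (x, y))
```

In a real cipher this step is repeated for many rounds with the halves swapped between rounds, and XOR is the more common combining operation; addition mod 2^32 is used here to mirror the "add them together" description.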
This was used in ciphers a ton, and still is; it's one of the main mechanisms for constructing ciphers. Often you want ciphers, and especially the layers of ciphers, to be invertible, because that has better cryptographic properties. This has actually been ported over: there is a 2017-18 paper on RevNets, reversible residual networks.
What it does is make the entire network invertible. You can apply it to any network, like a transformer: I do a forward pass, and then I can actually run the entire pass backwards as well. So the whole neural network is invertible, with exactly this construction. Applied to some layer, like a transformer layer, we've got this function f, which is our transformer layer.
Normally we would have just an input and a residual connection coming out, and it gets added like this, over here. But now the variation is that we've got two inputs, x and y. x goes through the function and gets added to y.
And then this becomes the new x, the output x, and this x becomes the new y. So what this is really doing, if you think two layers back (this is actually the thing you mentioned before), is a residual connection from two layers back: this y came from the previous layer and was the residual connection there. But because of this construction, the whole thing is invertible.
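A sketch of the reversible-stack form described here, new x = y + f(x) and new y = x, with each `f` a hypothetical stand-in for a transformer layer (a fixed random tanh map in pure Python). It runs the stack forward and then backward to recover the inputs. This follows the Feistel-style description in the conversation; the RevNet paper's own coupling uses two functions, F and G, arranged slightly differently.

```python
# Reversible stack in the form described: new_x = y + f(x), new_y = x.
# Each f is a hypothetical stand-in for a transformer layer (a fixed random tanh map).
import math
import random

def make_layer(width: int, seed: int):
    rng = random.Random(seed)
    w = [[rng.gauss(0.0, 1.0 / math.sqrt(width)) for _ in range(width)] for _ in range(width)]
    def f(v):
        return [math.tanh(sum(wij * vj for wij, vj in zip(row, v))) for row in w]
    return f

def add(a, b):
    return [ai + bi for ai, bi in zip(a, b)]

def sub(a, b):
    return [ai - bi for ai, bi in zip(a, b)]

def rev_forward(layers, x, y):
    for f in layers:
        # After the first step, y holds the previous x: a residual from two steps back.
        x, y = add(y, f(x)), x
    return x, y

def rev_inverse(layers, x, y):
    for f in reversed(layers):
        # The previous x is the current y; the previous y is the current x minus f(previous x).
        x, y = y, sub(x, f(y))
    return x, y

WIDTH = 8
layers = [make_layer(WIDTH, seed) for seed in range(4)]
rng = random.Random(0)
x0 = [rng.gauss(0.0, 1.0) for _ in range(WIDTH)]
y0 = [rng.gauss(0.0, 1.0) for _ in range(WIDTH)]

xN, yN = rev_forward(layers, x0, y0)
x_back, y_back = rev_inverse(layers, xN, yN)
err = max(abs(a - b) for a, b in zip(x0 + y0, x_back + y_back))
print(f"max reconstruction error: {err:.1e}")
```

The error is at the level of floating-point rounding: the inverse recomputes each f exactly and only the additions and subtractions round.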
Why do I care? What does invertibility matter for? The big thing it can be interesting for is training. Think of a forward pass in training: let's say I have four layers and I run them in 0, 1, 2, 3 order. I have to write all of the activations to HBM, so I get an HBM footprint here that is linear in the number of layers. Yep.
This can actually be the largest memory footprint during training. So this is normal training, and then I run the backward pass and read the activations in reverse: the forward pass goes forwards, the backward pass goes backwards, and I have to read them back out.
The idea of the RevNets paper is that because it's invertible, I don't need to store this at all. I can completely rematerialize it when I'm running my backward pass. So I run my forward pass, and then while I'm running my backward pass, I'm simultaneously, in lockstep, undoing all of the forward pass steps I did, in order to have the activations I need here. So this ends up being a memory saving, which is a nice idea.
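A sketch of the memory saving, with hypothetical random tanh layers standing in for transformer layers: a normal forward pass keeps every layer's activations, while the reversible backward pass walks the layers in reverse and recovers each layer's input from its output on the fly. Gradient computation is omitted; the recovered inputs are collected into a list here only so they can be checked against the stored ones.

```python
# Backward pass that rematerializes activations instead of storing them.
import math
import random

def make_layer(width, seed):
    """Hypothetical layer f: a fixed random tanh map (need not be invertible)."""
    rng = random.Random(seed)
    w = [[rng.gauss(0.0, 1.0 / math.sqrt(width)) for _ in range(width)] for _ in range(width)]
    return lambda v: [math.tanh(sum(wij * vj for wij, vj in zip(row, v))) for row in w]

def add(a, b):
    return [ai + bi for ai, bi in zip(a, b)]

def sub(a, b):
    return [ai - bi for ai, bi in zip(a, b)]

def forward_store_all(layers, x, y):
    """Normal training: keep every layer's (x, y); memory grows linearly with depth."""
    acts = [(x, y)]
    for f in layers:
        x, y = add(y, f(x)), x
        acts.append((x, y))
    return acts

def backward_rematerialize(layers, x, y):
    """Walk the layers last to first, undoing each step to recover that layer's input.
    Only the current (x, y) pair is live; a real backward pass would compute this
    layer's gradients here from the recovered input."""
    for f in reversed(layers):
        x, y = y, sub(x, f(y))
        yield x, y

WIDTH, DEPTH = 8, 4
layers = [make_layer(WIDTH, seed) for seed in range(DEPTH)]
rng = random.Random(1)
x0 = [rng.gauss(0.0, 1.0) for _ in range(WIDTH)]
y0 = [rng.gauss(0.0, 1.0) for _ in range(WIDTH)]

stored = forward_store_all(layers, x0, y0)
# Collected into a list only so we can compare against the stored activations.
recovered = list(backward_rematerialize(layers, *stored[-1]))

err = 0.0
for i, (rx, ry) in enumerate(recovered):
    sx, sy = stored[DEPTH - 1 - i]  # input to layer DEPTH - 1 - i
    err = max(err, max(abs(a - b) for a, b in zip(rx + ry, sx + sy)))

print(f"normal pass stores {len(stored)} activation pairs; reversible keeps 1 live pair")
print(f"max rematerialization error: {err:.1e}")
```

The trade is exactly the one described next: each layer's f is evaluated again during the backward pass in exchange for not keeping the per-layer activations.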
Interesting. It's spending more compute to save memory. That's right. Interesting. Actually, it's kind of the opposite of what you're doing with the KV cache. The KV cache. Yeah.
Yeah. Yeah. Spending more memory to save compute is generally profitable, given where...
Interesting. Cool. That was super fun. Thank you so much for doing it. I feel like it really vindicated the vision behind the studio and the blackboard. Cool. Thanks so much for doing it.
