ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Fused 커널은 왜 3~4배 빠른가 — GPU 메모리 계층과 Marlin의 비밀
    IT 2026. 5. 5. 23:40
    Fused 커널은 왜 3~4배 빠른가 — GPU 메모리 계층과 Marlin의 비밀

    들어가며 — "단계를 합쳤더니 왜 빨라지는가"

    직전 글에서 Q4 양자화는 곱셈 직전에 dequantize를 거쳐야 한다고 했다. 그리고 그 변환 비용을 거의 0으로 만드는 도구가 Marlin 같은 fused 커널이라고 짧게 언급하고 넘어갔다.

    그런데 잘 생각해 보면 이상한 구석이 있다. 똑같이 dequantize 하고, 똑같이 행렬곱을 하는데, 두 단계를 그냥 한 함수에 합쳤다고 어떻게 3~4배가 빨라지는가? 합쳐도 일의 양은 똑같지 않은가?

    답은 GPU 메모리 계층에 있다. 단순 분리 구현은 중간 결과를 한 번 메모리에 썼다가 다시 읽는다. 그 왕복이 실제 연산보다 훨씬 비싸다. Fused 커널은 그 왕복을 없앤다 — 중간 결과를 GPU 레지스터 안에서 곧장 다음 단계로 흘려보낸다.

    이 글은 그 메커니즘을 GPU 메모리 계층, Marlin, FlashAttention 사례를 통해 SVG 다이어그램과 함께 풀어쓴다.


    1. 먼저 GPU 메모리 계층을 알아야 한다

    GPU에는 여러 종류의 메모리가 계층으로 쌓여 있다. CPU의 RAM·L3·L2·L1 캐시·레지스터 구조와 비슷하지만, GPU는 그 격차가 훨씬 극단적이다.

    diagram

    한 단계 위로 올라갈 때마다 대역폭은 약 3~10배, 지연시간은 4~10배 빨라진다. 가장 위(Register)와 가장 아래(HBM)의 격차는 대략 300배다.

    그리고 가장 중요한 사실. 텐서코어와 CUDA 코어는 오직 레지스터에서만 데이터를 읽어 곱셈한다. HBM에 있는 데이터를 직접 곱하지 못한다 — 반드시 레지스터까지 끌어 올라온 뒤에야 연산이 일어난다.


    2. 단순 분리 구현은 왜 비효율적인가

    이제 dequantize와 GEMM을 단순히 두 단계로 분리한 구현을 따라가 보자.

    diagram

    핵심은 "두 커널" 구조 자체가 강요하는 비용이다. CUDA에서 별개의 커널은 별개의 시작·종료 시점을 가진다. 커널 ①이 끝난 시점에 만들어진 중간 결과는 SM 안에 보관할 방법이 없다 — 다음 커널 ②가 시작될 때 어떤 SM이 어떤 데이터를 다룰지 모르기 때문이다. 그래서 무조건 HBM에 한번 내려보내야 한다.

    여기서 "어차피 곱셈은 똑같이 한 번 하는데"라는 직관이 깨진다. 곱셈 자체가 아니라 그 곱셈을 위해 데이터를 끌어다 놓는 비용이 압도적이라는 게 GPU 컴퓨팅의 현실이다.


    3. Fused 커널은 어떻게 합치는가

    Fused 커널의 핵심 아이디어는 한 줄로 요약된다.

    "여러 단계를 하나의 커널로 합쳐, 중간 결과가 HBM에 절대 내려가지 않게 한다."
    중간 결과는 SM 안의 레지스터·shared memory에서만 살다가, 최종 결과만 HBM에 내려간다.

    diagram

    그림으로 보면 단순하지만, 실제로 이걸 가능하게 만든 게 결코 단순하지 않다. 세 가지 기술적 트릭이 들어 있다.

    트릭 1: 타일링 (Tiling)

    전체 가중치 행렬은 HBM에 있다. 그걸 작은 타일(예: 128×128)로 잘라서 SM 한 개가 한 타일을 책임진다. 한 타일은 SM의 shared memory(228KB)에 들어갈 수 있는 크기로 잡는다.

    트릭 2: 레지스터 안에서의 즉석 dequantize

    Q4 가중치는 shared memory에서 레지스터로 올라온 직후, 같은 스레드 안에서 dequantize 된다. 레지스터 ↔ ALU 통신은 거의 0 지연시간이라, 이 변환은 텐서코어가 곱셈하는 동안의 짧은 틈에 해치울 수 있다.

    트릭 3: 비동기 파이프라인

    HBM에서 다음 타일을 읽어 오는 동안, 텐서코어는 이미 레지스터에 있는 현재 타일을 곱셈한다. 두 작업이 겹쳐서 실행되니, HBM 읽기 시간이 곱셈 시간 뒤에 가려진다.

    diagram

    이 세 트릭의 결과로, fused 커널의 실제 시간은 "가장 비싼 한 가지"로 결정된다 — HBM 읽기와 텐서코어 곱셈 중 더 오래 걸리는 쪽. 단순 분리 구현이 그 둘을 합친 시간이 걸렸던 것과 대조된다.


    4. 두 종류 코어를 동시에 — fused가 노는 회로를 깨우는 법

    여기서 자연스러운 의문이 하나 생긴다. "잠깐, dequantize는 누가 풀어주는가? 텐서코어가 행렬곱만 한다면, dequantize는 어떤 회로가 처리하는가?"

    답은 CUDA 코어다. 그리고 fused 커널의 진짜 위력은 여기서부터 시작된다 — 한 SM 안에서 CUDA 코어와 텐서코어가 동시에 일을 한다.

    SM 안에는 여러 종류 회로가 같이 있다

    하나의 SM(Streaming Multiprocessor)은 여러 종류의 연산 유닛을 한 칩 안에 담고 있다. 마치 한 공장 라인에 다른 기계가 동시에 돌아가는 것과 같다.

    diagram

    핵심은 이 회로들이 물리적으로 분리돼 있다는 점이다. 그래서 같은 사이클에 서로 다른 명령을 동시에 처리할 수 있다 — CPU의 하이퍼스레딩이나 슈퍼스칼라와 비슷한 개념이다.

    fused 안에서 누가 뭘 하는가

    Marlin 같은 fused 커널이 한 GEMM 타일을 처리하는 동안, 위 유닛들의 분업은 다음과 같다.

    단계 담당 유닛 하는 일
    1. HBM에서 Q4 가중치 가져오기 Load/Store cp.async로 shared memory에 복사
    2. shared mem → 레지스터 Load/Store 워프 스레드들이 레지스터로 끌어 올림
    3. 4비트 unpack + scale 곱 + min 합 CUDA 코어 레지스터 안의 정수를 FP16으로 풀어줌
    4. FP16 결과로 행렬곱 누적 텐서 코어 mma.sync 발행
    5. 누적 결과를 HBM에 출력 Load/Store 최종 결과만 한 번 씀

    여기서 결정적인 사실은 3번과 4번이 시간상 겹친다는 것이다.

    두 코어가 어떻게 동시에 돌아가는가

    워프 스케줄러는 한 SM당 보통 32~64개 워프를 동시에 보유한다. 매 사이클마다 어떤 워프의 어떤 명령을 어느 유닛에 발행할지 정한다. 한 워프가 텐서 코어에 mma.sync를 발행하면 그 명령은 백그라운드에서 여러 사이클 동안 실행된다. 그 사이 같은 워프 또는 다른 워프가 CUDA 코어에 dequantize 명령을 발행한다.

    diagram

    그림에서 빨간 점선 영역을 보면, 같은 시간대에 Load/Store 유닛이 다음 타일을 가져오고, CUDA 코어가 직전 타일을 풀어주고, 텐서 코어가 그 직전 타일의 GEMM을 돌리고 있다. 세 유닛이 모두 일하고 있다.

    그래서 dequantize 비용이 "공짜"가 된다

    텐서 코어는 워낙 무거운 일을 한다 — 한 번의 mma.sync가 수십 사이클에 걸쳐 진행된다. 그 시간 동안 CUDA 코어는 다음 타일의 dequantize를 끝내버린다. 결과적으로 dequantize에 걸린 사이클이 텐서 코어 GEMM 시간 뒤에 가려진다 — 전체 시간으로 보면 dequantize 비용이 거의 0인 것처럼 보인다.

    이게 단순 분리 구현과의 결정적 차이다. 단순 분리에서는 dequantize 커널이 끝나야 GEMM 커널이 시작된다 — 두 작업이 직렬로 쌓인다. fused에서는 두 작업이 같은 SM의 다른 회로에서 동시에 돌아간다.

    그래서 fused 커널의 효율을 한 줄로 다시 정리하면 이렇다.

    Fused는 메모리 왕복을 없앨 뿐 아니라, 평소엔 놀고 있던 회로(CUDA 코어)를 텐서 코어 시간에 끼워 넣어 가동률을 끌어올린다.
    "한 커널에 합쳤다"의 진짜 의미는 "한 SM 안의 모든 회로가 동시에 일하게 만들었다"이다.


    5. 사례: Marlin — Q4 dequant + GEMM fused

    Marlin은 IST Austria가 2024년에 발표한 fused 커널로, 위에서 설명한 트릭을 모두 결합했다. Q4 가중치 + FP16 활성값 → FP16 결과를 한 커널로 처리한다.

    실측 효과를 단순 분리 대비 그래프로 그리면 이렇다.

    diagram

    여기서 흥미로운 사실이 보인다. 단순 dequant+GEMM 구현은 오히려 cuBLAS FP16(원본)보다 느리다. Q4로 메모리는 절감했지만, 중간 결과 HBM 왕복이 그 절감을 다 까먹고 거기에 dequantize 비용까지 추가됐다. "양자화 = 빠름"이 아니다. "양자화 + 똑똑한 커널 = 빠름"이다.

    Marlin은 cuBLAS FP16 대비 약 3.5배. 같은 Q4 가중치를 단순 구현 대비로는 약 5배. 같은 데이터, 같은 곱셈인데 단지 "두 커널을 한 커널로 합쳤다"는 이유로.


    6. 그럼 왜 모든 걸 fused로 만들지 않는가

    여기까지 읽으면 자연스러운 질문이 나온다. "좋아 보이는데, 왜 모든 GPU 연산을 fused로 안 만드는가?"

    세 가지 한계가 있다.

    diagram

    ① 수작업 비용

    Fused 커널은 일반 CUDA 코드로는 못 짠다. 레지스터 할당, shared memory 뱅크 충돌, 명령어 스케줄링까지 직접 다뤄야 한다. PTX(중간 표현)나 SASS(어셈블리)를 직접 쓰는 일도 흔하다. 그래서 (Q4+FP16, Q4+FP8, AWQ+FP16, GPTQ+FP8...) 조합마다 별도 커널을 누가 손으로 짜야 하고, 그 일을 할 수 있는 사람이 세계에 많지 않다.

    ② 회로 처리량 상한

    Fused는 메모리 왕복을 없앴을 뿐, 곱셈 자체는 결국 FP16/FP8 텐서코어가 한다. 그래서 처리량 상한은 그 회로의 한계와 같다. "FP8 native보다 빠른 Q4 fused는 존재하지 않는다." Marlin이 화려해 보여도, 같은 모델의 FP8 native 체크포인트가 있다면 그게 일반적으로 더 빠르다.

    ③ SM 자원 압박

    한 커널에 너무 많은 단계를 합치면 레지스터 사용량이 늘어난다. 레지스터가 부족하면 SM당 동시에 돌릴 수 있는 스레드 수(occupancy)가 줄어, 메모리 지연시간을 가릴 여지가 사라진다. 그래서 fused에도 "최적의 합치기 단위"가 있고, 무작정 더 합치면 오히려 느려진다.


    7. 효과 정리 — 무엇이 바뀌었는가

    Fused 커널이 바꾼 것을 표로 정리하면 이렇다.

    관점 단순 분리 Fused 커널
    중간 결과 위치 HBM (느림) 레지스터 (300배 빠름)
    HBM 트래픽 2~4배 (중간 결과 왕복) 최소 (입출력만)
    커널 시작 비용 여러 번 1번
    비동기 파이프라이닝 커널 사이는 직렬 커널 안에서 자유롭게 겹침
    코드 작성 비용 낮음 (조합 자유) 매우 높음 (저수준 튜닝)
    대표 사례 일반 PyTorch 코드 Marlin, FlashAttention, xFormers

    실측 효과를 한 줄로 줄이면 "GPU 연산기를 놀리지 않게 한다"이다. 단순 분리에서는 연산기가 메모리 트래픽을 기다리며 80%를 놀고 있었다면, fused에서는 그 비율이 30~40%까지 떨어진다. 같은 회로, 같은 연산을 하는데 일하는 시간이 늘어나니 throughput이 늘어난다.


    8. 정리 — Fused는 "메모리 왕복을 없애는 기술"

    한 줄로 줄이면 이렇다.

    Fused 커널 = "여러 단계를 하나의 GPU 커널로 합쳐, 중간 결과가 HBM 대신 SM 내부 레지스터에서만 살게 하는 기술"

    왜 빨라지는지 정리하면 세 가지 원인이 누적된다.

    1. HBM 왕복 제거: 중간 결과를 메모리에 안 내림 — 가장 큰 이득.
    2. 비동기 파이프라이닝: 메모리 읽기와 텐서코어 곱셈을 시간상 겹침 — 한쪽이 다른 쪽 시간에 가려짐.
    3. 커널 시작 오버헤드 제거: 여러 번 커널을 호출하지 않음 — 작은 단위 작업에서 의미 있음.

    한계도 분명하다. 곱셈 자체의 처리량 상한은 못 넘는다. Marlin이 cuBLAS FP16 대비 3~4배 빠른 건 Q4 가중치의 메모리 절감 + fused의 왕복 제거를 합쳐서 그렇다. 같은 조건에서 FP8 native 체크포인트가 있다면 그게 일반적으로 더 빠르다.

    그래서 Fused 커널의 진짜 위치는 "양자화의 약점을 없애주는 보조 기술"이다. Q4 가중치는 메모리는 좋지만 dequantize가 비싸다 — 그 비용을 fused가 0에 가깝게 만들어준다. 이 둘이 합쳐졌을 때 "양자화의 진짜 가치"가 드러난다.

    같은 사고방식이 normalization(fused LayerNorm), activation(fused GELU) 같은 다른 GPU 연산에도 그대로 적용된다. "GPU 연산이 빨라지는 길은 곱셈 회로를 늘리는 게 아니라, 곱셈 회로가 놀지 않게 데이터를 흘려주는 데 있다" — 이게 Marlin이 가르치는 교훈이다.


    이 글은 생성형 AI의 도움을 받아 작성되었습니다. 원본 자료를 기반으로 AI가 초안을 생성하고, 작성자가 검토·편집하였습니다.

Designed by Tistory.