EE/Digital Design

[Degital Design] 2D Systolic Array for matrix multiplication

아이스얼그레이 2022. 9. 30. 17:44
Matrix Multiplication(이하 MM)을 가속하기 위한 Systolic array architecture에 대해서 알아보겠습니다.

 

http://ashanpeiris.blogspot.com/2015/08/digital-design-of-systolic-array.html

 

Digital design of systolic array architecture for matrix multiplication

Systolic architecture consists of an array of processing elements, where data flows between neighboring elements, synchronously , from di...

ashanpeiris.blogspot.com

위 링크의 게시물을 참고해 작성한 글입니다,

 

 

다음과 같은 2 x 2 정방 행렬 A, B의 간단한 MM을 고려해봅시다.

 

위 행렬곱을 손으로 계산하기 위해서는

 

 

행렬 C의 좌표별로 이러한 내적 연산을 4번 수행해야 합니다.

 

이때 발생하는 곱셈, 덧셈 연산의 횟수를 세어보면, 곱셈 8번, 덧셈 4번입니다. 이는 사람이 손으로 계산하기에 크게 문제가 되지는 않습니다. 하지면 HW로 이를 계산하려면 어떻게 될까요?

 

HW에서도 곱셈, 덧셈은 결국 수행해야 하니 제쳐두고, 진짜 문제는 memory access에서 발생합니다. Computer architecture 수업을 들어보신 분들은 아시겠지만, memory access는 cost가 상당히 큰 operation입니다. 그런데 저렇게 간단한 MM을 위해 매번 data를 main memory에서 가져온다? 이건 상당히 비효율적입니다.


이러한 문제를 해결하기 위한 parallele architecture가 systolic array입니다. 본 게시글에서는 systolic array에 필요한 buffer에 대해서는 다루지 않았습니다.

 

systolic array의 구조는 다음과 같습니다.

이때 사진의 각각의 네모를 Processing Element(이하 PE)라고 하고 각 PE의 내부에는 MAC 연산을 수행할 수 있는 다음과 같은 module이 들어있습니다.

 

이 PE는 오직 MAC 연산만 수행합니다. MM을 위해 특화된 architecture이기 때문에 이러한 형태로 설계되었습니다. in_a, in_b가 input으로 들어오면, out_c에 out_a * out_b를 더함과 동시에 out_a, out_b를 통해 in_a, in_b를 다음 PE로 전달해줍니다. PE array를 코드로 작성해보면 아시겠지만 in_a, in_b를 다음 PE로 넘겨주는 동작 자체가 data reuse입니다.

 

MM에서 결과 행렬의 n번째 행은 행렬 A의 n번째 행을 계속 가져다 쓰기 때문에, memory에서 한 번 불러온 뒤 계속 써줘야 합니다. 그리고 결과 행렬의 n번째 열은 행렬 B의 n번째 열을 계속 가져다 쓰고, 그 data를 계속 재사용해줘야 합니다.

 

verilog로 작성하면 다음과 같습니다.

module pe #(
    parameter data_size = 8
) (
    input                       clk,
    input                       reset,
    input           [data_size-1:0]     		in_a,
    input           [data_size-1:0]     		in_b,

    output  reg     [data_size-1:0]         	out_a,
    output  reg     [data_size-1:0]         	out_b,
    output  reg     [2 * data_size - 1:0]       out_c
    );

    always @(posedge clk) begin
        if(reset) begin // reset 들어오면 a, b, c initialize
            out_a = 0;
            out_b = 0;
            out_c = 0;
        end
        else begin // else, clk에 동기화되어서 MAC 연산 및 다음 pe로 넘겨주는 동작
            out_c <= out_c + in_a * in_b;
            out_a <= in_a;
            out_b <= in_b;
        end
    end

endmodule

MAC 연산과 in_a, in_b를 다음 PE로 넘겨주는 코드를 non-blocking statement로 썼는데, 이렇게 하지 않으면 연산 간에 약간의 delay가 생겨서 in_a, in_b를 넘겨주고 다음 PE MAC 연산을 하는 타이밍이 어긋나서 제대로 된 결과가 안 나옵니다.(이는 제가 사용한 Sysnopsys VCS 기준이며 다른 simulator에서는 정상 동작할 수 도 있습니다.)

 

아무튼 PE는 이와 같이 동작하고 다음은 PE를 도식화한 그림입니다.

두 개의 입력 a, b를 받고 그것을 그대로 output a, b로 내보냅니다. 이때 diagram에서 왼쪽, 위쪽에서 들어오는 signal이 input이고, 오른쪽, 아래쪽으로 나가는 signal이 output입니다. 그리고 동시에 in_a * in_b를 c라는 변수에 누적시킵니다.

MM이 곱한거 더하고 곱한거 더하고 이런 연산이 계속되기 떄문에 MAC으로 빠르게 처리할 수 있습니다.


다음은 PE로 network를 만들었을 때 어떻게 동작하는지 보겠습니다.

Systolic array

위 그림과 같이 PE1의 output이 PE2, PE4로 전달되고 PE2의 output에 PE3, PE5로 전달되고.... 이렇게 연결되어서 연산에 사용한 data를 다음 PE로 넘겨주게 됩니다. 화살표 방향에 나타난 것 처럼 뒤쪽에 있는 PE에서 앞쪽에 있는 PE로 data를 넘겨줄 수는 없습니다. MM의 data reuse pattern상 그럴 필요도 없습니다.

 

이를 verilog로 기술하려면 각 PE를 연결해주는 wire를 선언해야줘야 합니다. 이는 설계자의 의도대로 가독성있게 정해주면 됩니다. 저는 data a인지 data b인지 구분하고 그 뒤에 시작 index, 끝 index를 써줬습니다.

 

예를 들어 PE1에서 PE2로 넘어가는 data a라면 a12 이런 식으로 말이죠. 그리고 당연히 clk signal에 동기화되어야 히기 때문에 각 PE module에 clk과 reset이 들어가야 합니다.

 

verilog로 작성하면 다음과 같습니다.

module sys_array #(
    parameter data_size = 8
) (
    input       clk,
    input       reset,
    input   [data_size-1:0]   	a1, a2, a3, b1, b2, b3,
    output  [2*data_size-1:0]   c1, c2, c3, c4, c5, c6, c7, c8, c9
    );
 
    wire [data_size-1:0] a12, a23, a45, a56, a78, a89;
    wire [data_size-1:0] b14, b25, b36, b47, b58, b69;
    
    pe pe1 (.clk(clk), .reset(reset), .in_a(a1), .in_b(b1), .out_a(a12), .out_b(b14), .out_c(c1));
    pe pe2 (.clk(clk), .reset(reset), .in_a(a12), .in_b(b2), .out_a(a23), .out_b(b25), .out_c(c2));
    pe pe3 (.clk(clk), .reset(reset), .in_a(a23), .in_b(b3), .out_a(), .out_b(b36), .out_c(c3));
    pe pe4 (.clk(clk), .reset(reset), .in_a(a2), .in_b(b14), .out_a(a45), .out_b(b47), .out_c(c4));
    pe pe5 (.clk(clk), .reset(reset), .in_a(a45), .in_b(b25), .out_a(a56), .out_b(b58), .out_c(c5));
    pe pe6 (.clk(clk), .reset(reset), .in_a(a56), .in_b(b36), .out_a(), .out_b(b69), .out_c(c6));
    pe pe7 (.clk(clk), .reset(reset), .in_a(a3), .in_b(b47), .out_a(a78), .out_b(), .out_c(c7));
    pe pe8 (.clk(clk), .reset(reset), .in_a(a78), .in_b(b58), .out_a(a89), .out_b(), .out_c(c8));
    pe pe9 (.clk(clk), .reset(reset), .in_a(a89), .in_b(b69), .out_a(), .out_b(), .out_c(c9));
    
endmodule

생각보다 간단하죠? PE를 연결하는 wire 선언과, 그게 어디서 어디로 들어가는지만 잘 구분해주면 됩니다.

 

그리고 c1, c2, c3,... ,c8, c9는 결과 행렬 C의 각 element입니다. 그러니까 

 

c1  c2  c3

c4  c5  c6

c7  c8  c9

이런 형식의 행렬라는 것이죠. 그리고 각 element와 같은 index를 가지는 PE가 각 element를 연산합니다.