EE/AI Silicon System

[AI Silicon System] Project 1: Systolic array

아이스얼그레이 2023. 5. 3. 15:41

[AI Silicon System] 게시물은 KAIST 김주영 교수님의 강의자료와 Project document를 참고하여 작성하였습니다.


첫 번째 Project입니다. 16 X 16 MAC array를 설계하면 됩니다. 이 단계에서는 Memory control은 고려하지 않아도 되고, tb에서 MAC array에 맞게 data를 feeding 해줍니다.

1. Objective

  • Implement a systolic MAC array datapath
  • Compute matrix multiplication using MAC arry

2. Specification

  • Input feature map bit-width: 16 bit
  • Weight bit-width: 8 bit
  • Output feature map bit-width: 32 bit
  • MAC array size (physical): 16 X 16
  • Data dimension
    • Input feature map matrix: 16 X 1024
    • Weight matrix: 16 X 16
    • Output feature map matrix: 16 X 1024

Input data 한 개가 16bit, Weight 한 개가 8bit, Output data 한 개가 32bit 입니다.

그리고 MAC array는 16 X 16의 configuration을 가지며, 따라서 총 256개의 MAC unit이 사용됩니다.

3. Overall design

  • Systolic MAC array

전반적인 architecture는 위와 같습니다. Project 1에서 설계할 부분은 16 X 16 MAC Array이며, 나머지 module은 제공된 testbench에 구현된 equivalent 한 동작으로 대체됩니다. Project를 진행하면서 testbench에서 대체되는 동작들을 직접 설계하게 됩니다.

  • Data flow

딱 위 그림이 Systolic array의 정석적은 Data flow를 나타내는 그림입니다. Weight stationary 방식을 따르니, 먼저 weight matrix를 MAC array의 각 MAC unit에 저장합니다. 이때 오른쪽의 MAC 내부 microarchitecture 처럼 weight를 위한 register에 저장됩니다.

그리고 Input matrix를 흘려주는데, 그림처럼 diagonal 한 순서로 Input matrix를 넣어줘야 합니다. systolic array는 dataflow processor(ISA가 필요 없고 data만 넣어주면 정해진 clock cycle 뒤에 연산이 완료됨!)이기 때문에, 설계한 architecture에 맞는 MM을 위한 data만 정확하게 넣어주면 연산이 알아서 진행됩니다.

4. Timing diagram

딱 이 timing diagram이 나오도록 설계를 해야 합니다!! w_data_in을 넣어주고 그 뒤로 ifmap_data가 쭉쭉쭉 들어가는 그런 형상입니다. 그리고 16 clock cycle이 지나면 0번째 column부터 ofmap_data_out이 나오기 시작해서, ofmap_data가 쭉쭉쭉 나옵니다. 많이 축약해서 설명한 거긴 한데, 설계해 보면 이해됩니다:)

5. System verilog implementation

이제 .sv 로 설계를 해봅시다. Systolic array는 HW 구조가 상당히 간단해서 MAC unit을 잘 짜놓으면 bottom-up 설계를 쉽게 할 수 있어서, bottom-up 방식으로 설계했습니다.

5.1 MAC.sv

`timescale 1 ns / 1 ps

module MAC
#(
    parameter IFMAP_BITWIDTH                                            = 16,
    parameter W_BITWIDTH                                                = 8,
    parameter PSUM_BITWIDTH                                             = 32
)
(
    input  logic                                                        clk,
    input  logic                                                        rstn,

    input  logic                                                        ifmap_enable_in,
    input  logic signed [IFMAP_BITWIDTH-1:0]                            ifmap_data_in,

    input  logic                                                        w_enable_in,
    input  logic signed [W_BITWIDTH-1:0]                                w_data_in,

    input  logic                                                        psum_enable_in,
    input  logic signed [PSUM_BITWIDTH-1:0]                             psum_data_in,

    output logic                                                        ifmap_valid_out,
    output logic signed [IFMAP_BITWIDTH-1:0]                            ifmap_data_out,

    output logic signed [W_BITWIDTH-1:0]                                w_data_out,
    
    output logic                                                        psum_valid_out,
    output logic signed [PSUM_BITWIDTH-1:0]                             psum_data_out
    
);

    always @(posedge clk) begin
        // wegiht prefetch logic
        if (!rstn) begin
            w_data_out          <=      {(W_BITWIDTH){1'b0}};
        end
        else if (w_enable_in) begin
            w_data_out          <=      w_data_in;
        end
        else begin
            w_data_out          <=      w_data_out;
        end

        // ifmap forwarding logic
        if (!rstn) begin
            ifmap_data_out      <=      {(IFMAP_BITWIDTH){1'b0}};
            ifmap_valid_out     <=      0;
        end
        else if (ifmap_enable_in)begin
            ifmap_data_out      <=      ifmap_data_in;
            ifmap_valid_out     <=      ifmap_enable_in;
        end
        else begin
            ifmap_data_out      <=      ifmap_data_out;
            ifmap_valid_out     <=      0;
        end

        // psum MAC operation & psum forwarding logic
        if (!rstn) begin
            psum_data_out       <=      {(PSUM_BITWIDTH){1'b0}};
            psum_valid_out      <=      0;
        end
        else if (ifmap_valid_out && psum_enable_in) begin
                psum_data_out        =      psum_data_in + ifmap_data_out * w_data_out;
                // sign extenstion
                psum_data_out        =      {{(W_BITWIDTH){psum_data_out[IFMAP_BITWIDTH+W_BITWIDTH-1]}}, psum_data_out};
                psum_valid_out       =      psum_enable_in;
        end
        else begin
            psum_data_out       <=      psum_data_out;
            psum_valid_out      <=      psum_valid_out;
        end 
    end
    
endmodule

음 어디서부터 설명을 해야 할까요.. weight_data, ifmap_data (첫 번째, 두 번째 if-else if statement)는 간단한 forwarding logic이라서 동작이 간단합니다. enable signal(w_enable_in, ifmap_enable_in)이 들어오면 register에 저장되어 있던 직적 timing의 data를 다음 PE로 넘겨주도록 동작합니다.

psum(partial sum)을 계산하는 부분이 MAC에서 핵심이 되는 부분입니다. MM 연산에서 결과 행렬의 한 element (Vector-Vector product)의 partial sum이 여러 MAC unit을 거치면서 연산이 됩니다. psum_data_in에 ifmap_data_in, w_data_in이 아닌 ifmap_data_out, w_data_out을 곱해서 더해주는 이유는, *_in이 들어오고 나서 register에 저장된 후에 다음 clock cycle에 계산이 가능하기 때문입니다.

psum computation 부분에서만 non-blocking "<=" 이 아닌 blocking "=" 을 사용했습니다. psum_ data_out을 연산하고 sign을 고려해서 sign extension을 해주는 코드를 추가했습니다. 이때 non-blocking을 쓰면 psum_data_out의 value가 충돌이 생기기 때문에, 앞 code가 끝나야 sequential 하게 code를 수행하는 blocking을 사용해서 그 문제를 피했습니다.

5.2 MAC array

`timescale 1 ns / 1 ps

module MacArray
#(
    parameter MAC_ROW                                                       = 16,
    parameter MAC_COL                                                       = 16,
    parameter IFMAP_BITWIDTH                                                = 16,
    parameter W_BITWIDTH                                                    = 8,
    parameter OFMAP_BITWIDTH                                                = 32
)
(
    input  logic                                                            clk,
    input  logic                                                            rstn,

    input  logic                                                            w_prefetch_in,      // Not used in this module

    input  logic                                                            w_enable_in,
    input  logic signed [MAC_COL-1:0][W_BITWIDTH-1:0]                       w_data_in,

    input  logic                                                            ifmap_start_in,     // Not used in this module

    input  logic        [MAC_ROW-1:0]                                       ifmap_enable_in,
    input  logic signed [MAC_ROW-1:0][IFMAP_BITWIDTH-1:0]                   ifmap_data_in,

    output logic        [MAC_COL-1:0]                                       ofmap_valid_out,
    output logic signed [MAC_COL-1:0][OFMAP_BITWIDTH-1:0]                   ofmap_data_out
);
    // your code here

    // Below are wire for PE interconnection
    logic [MAC_ROW-1:0]                     ifmap_valid [MAC_COL:0];
    logic [MAC_ROW-1:0][IFMAP_BITWIDTH-1:0] ifmap_data  [MAC_COL:0];
    
    logic [MAC_COL-1:0][W_BITWIDTH-1:0]     w_data      [MAC_ROW:0];

    logic [MAC_COL-1:0]                     psum_valid  [MAC_ROW:0];
    logic [MAC_COL-1:0][OFMAP_BITWIDTH-1:0] psum_data   [MAC_ROW:0];


    assign                                  ifmap_valid [0]     =   ifmap_enable_in;
    assign                                  ifmap_data  [0]     =   ifmap_data_in;

    assign                                  w_data      [0]     =   w_data_in;

    // to initiate psum accumulation with traversing PEs
    assign                                  psum_valid  [0]     =   {(MAC_COL){1'b1}};
    assign                                  psum_data   [0]     =   {(MAC_COL * W_BITWIDTH){1'b0}};

    // Not [MAC_ROW-1]!! because these are wires from last row of PE array 
    assign                                  ofmap_valid_out     =   psum_valid  [MAC_ROW];
    assign                                  ofmap_data_out      =   psum_data   [MAC_ROW];

    // i for row, j for col
    genvar i, j;
    generate
        for (i = 0; i < MAC_ROW; i++) begin: row_iter
            for (j = 0; j < MAC_COL; j++) begin: col_iter
                MAC
                #(
                    .IFMAP_BITWIDTH                         ( IFMAP_BITWIDTH            ),
                    .W_BITWIDTH                             ( W_BITWIDTH                ),
                    .PSUM_BITWIDTH                          ( OFMAP_BITWIDTH            )   
                )
                MACunit
                // vertically forwarding    -> mapping via [i][j]
                // horizontally forwarding  -> mapping via [j][i]
                (
                    .clk                                    ( clk                       ),
                    .rstn                                   ( rstn                      ),
                    
                    .ifmap_enable_in                        ( ifmap_valid   [j][i]      ),
                    .ifmap_data_in                          ( ifmap_data    [j][i]      ),

                    .w_enable_in                            ( w_enable_in               ),
                    .w_data_in                              ( w_data        [i][j]      ),

                    .psum_enable_in                         ( psum_valid    [i][j]      ),
                    .psum_data_in                           ( psum_data     [i][j]      ),

                    .ifmap_valid_out                        ( ifmap_valid   [j+1][i]    ),
                    .ifmap_data_out                         ( ifmap_data    [j+1][i]    ),
                    
                    .w_data_out                             ( w_data        [i+1][j]    ),

                    .psum_valid_out                         ( psum_valid    [i+1][j]    ),
                    .psum_data_out                          ( psum_data     [i+1][j]    )        
                );
            end
        end
    endgenerate

endmodule

MAC array가 뭔가 좀 더 복잡해 보일 수도 있는데, 생각보다 간단합니다.

HW에서는 왼쪽 그림에 나와있는 MAC을 연결하는 선들(wire라고 합니다.)을 다 선언하고 연결 관계(interconnection)를 기술해 줘야 합니다. 이때 MAC array는 규칙적인 모양의 wire 형태를 가지기 때문에 matrix 형태로 선언하면 generate 문으로 handling 하기 좋습니다.

그래서 ifmap_valid, ifmap_data, w_data, psum_valid, psum_data를 MAC_ROW X MAC_COL 혹은 MAC_COL X MAC_ROW configuration으로 선언했습니다. 그 후 generate 문 내부에서 MAC inteconnection에 맞도록 index를 지정해 주면 됩니다!!

참고로 generate 문은 HW에서 사용할 수 있는 반복문 정도로 이해하면 됩니다. 그냥 for 문을 사용할 수도 있지만 Synthesis를 할 수 없기 때문에 반복적인 HW configuration을 기술하려면 generate를 써야 하며 상당히 중요한 syntax입니다.

6. Result

top level에서 시작하는 부분의 timing diagram입니다. 자세히 보실 필요는 없고, timing 잘 맞고 error도 없게 나타났습니다!

이건 전체 view timing diagram인데, 깔끔하게 연산 잘 되네요.

신기한 건 reference data를 만들어놔서 error check를 하도록 testbench를 작성했다는 겁니다. system verilog 초고수이신듯..

어쨌든 Project 1 Done!!

'EE > AI Silicon System' 카테고리의 다른 글

[AI Silicon System] Project 0: Overview  (0) 2023.05.03