手写矩阵乘法

本文最后更新于 2024年2月2日

\(XY\)

\[ \begin{aligned} \pmb Z_{rc}&=\pmb X_{rn}\pmb Y_{nc} \\&= \begin{bmatrix} x_{11} & \cdots & x_{1n} \\ \vdots & \ddots & \vdots \\ x_{r1} & \cdots & x_{rn} \\ \end{bmatrix} \begin{bmatrix} y_{11} & \cdots & y_{1c} \\ \vdots & \ddots & \vdots \\ y_{n1} & \cdots & y_{nc} \\ \end{bmatrix} \\&= \begin{bmatrix} (x_{11}y_{11}+\ldots+x_{1n}y_{n1}) & \cdots & (x_{11}y_{1c}+\ldots+x_{1n}y_{nc}) \\ \vdots & \ddots & \vdots \\ (x_{r1}y_{11}+\ldots+x_{rn}y_{n1}) & \cdots & (x_{r1}y_{1c}+\ldots+x_{rn}y_{nc}) \\ \end{bmatrix} \end{aligned} \]

/*
 * Compute Z = X * Y.
 *
 * All three matrices are stored contiguously, one row after another.
 *
 * _z  : output of row * col elements; every element is overwritten.
 * _x  : left operand of row * c_r elements (row rows, c_r columns).
 * _y  : right operand of c_r * col elements (c_r rows, col columns).
 * row : number of rows of X and of Z.
 * c_r : number of columns of X, which is also the number of rows of Y.
 * col : number of columns of Y and of Z.
 */
void mulmm(double *_z, double const *_x, double const *_y, size_t row, size_t c_r, size_t col)
{
    size_t const count = row * col;
    size_t i, j, k;

    /* Start from an all-zero result; the products below are added onto it. */
    for (i = 0; i < count; ++i)
    {
        _z[i] = 0;
    }

    for (i = 0; i < row; ++i)
    {
        size_t const zoff = i * col; /* start of row i of Z */
        size_t const xoff = i * c_r; /* start of row i of X */
        for (k = 0; k < c_r; ++k)
        {
            size_t const yoff = k * col; /* start of row k of Y */
            /* Add X[i][k] times row k of Y onto row i of Z. */
            for (j = 0; j < col; ++j)
            {
                _z[zoff + j] += _x[xoff + k] * _y[yoff + j];
            }
        }
    }
}

\(X^TY\)

\[ \begin{aligned} \pmb Z_{rc}&=\pmb X_{nr}^{T}\pmb Y_{nc} \\&= \begin{bmatrix} x_{11} & \cdots & x_{1r} \\ \vdots & \ddots & \vdots \\ x_{n1} & \cdots & x_{nr} \\ \end{bmatrix}^T \begin{bmatrix} y_{11} & \cdots & y_{1c} \\ \vdots & \ddots & \vdots \\ y_{n1} & \cdots & y_{nc} \\ \end{bmatrix} \\&= \begin{bmatrix} (x_{11}y_{11}+\ldots+x_{n1}y_{n1}) & \cdots & (x_{11}y_{1c}+\ldots+x_{n1}y_{nc}) \\ \vdots & \ddots & \vdots \\ (x_{1r}y_{11}+\ldots+x_{nr}y_{n1}) & \cdots & (x_{1r}y_{1c}+\ldots+x_{nr}y_{nc}) \\ \end{bmatrix} \\&= \begin{bmatrix} x_{11}y_{11} & \cdots & x_{11}y_{1c} \\ \vdots & \ddots & \vdots \\ x_{1r}y_{11} & \cdots & x_{1r}y_{1c} \\ \end{bmatrix}+\cdots+ \begin{bmatrix} x_{n1}y_{n1} & \cdots & x_{n1}y_{nc} \\ \vdots & \ddots & \vdots \\ x_{nr}y_{n1} & \cdots & x_{nr}y_{nc} \\ \end{bmatrix} \end{aligned} \]

/*
 * Compute Z = X^T * Y.
 *
 * All three matrices are stored contiguously, one row after another.
 * The result is built as a sum of c_r outer products: for each k, column k
 * of X^T (that is, row k of X) times row k of Y.
 *
 * _z  : output of row * col elements; every element is overwritten.
 * _x  : left operand before transposition, c_r rows of row elements each.
 * _y  : right operand, c_r rows of col elements each.
 * c_r : number of rows shared by X and Y.
 * row : number of columns of X, which is the number of rows of Z.
 * col : number of columns of Y and of Z.
 */
void multm(double *_z, double const *_x, double const *_y, size_t c_r, size_t row, size_t col)
{
    size_t const count = row * col;
    size_t i, j, k;

    /* Start from an all-zero result; the products below are added onto it. */
    for (i = 0; i < count; ++i)
    {
        _z[i] = 0;
    }

    for (k = 0; k < c_r; ++k)
    {
        size_t const xoff = k * row; /* start of row k of X */
        size_t const yoff = k * col; /* start of row k of Y */
        for (i = 0; i < row; ++i)
        {
            size_t const zoff = i * col; /* start of row i of Z */
            /* Add X[k][i] times row k of Y onto row i of Z. */
            for (j = 0; j < col; ++j)
            {
                _z[zoff + j] += _x[xoff + i] * _y[yoff + j];
            }
        }
    }
}

\(XY^T\)

\[ \begin{aligned} \pmb Z_{rc}&=\pmb X_{rn}\pmb Y_{cn}^T \\&= \begin{bmatrix} x_{11} & \cdots & x_{1n} \\ \vdots & \ddots & \vdots \\ x_{r1} & \cdots & x_{rn} \\ \end{bmatrix} \begin{bmatrix} y_{11} & \cdots & y_{1n} \\ \vdots & \ddots & \vdots \\ y_{c1} & \cdots & y_{cn} \\ \end{bmatrix}^T \\&= \begin{bmatrix} (x_{11}y_{11}+\ldots+x_{1n}y_{1n}) & \cdots & (x_{11}y_{c1}+\ldots+x_{1n}y_{cn}) \\ \vdots & \ddots & \vdots \\ (x_{r1}y_{11}+\ldots+x_{rn}y_{1n}) & \cdots & (x_{r1}y_{c1}+\ldots+x_{rn}y_{cn}) \\ \end{bmatrix} \end{aligned} \]

/*
 * Compute Z = X * Y^T.
 *
 * All three matrices are stored contiguously, one row after another.
 * Each element of Z is the dot product of one row of X with one row of Y.
 *
 * _z  : output of row * col elements; every element is overwritten.
 * _x  : left operand, row rows of c_r elements each.
 * _y  : right operand before transposition, col rows of c_r elements each.
 * row : number of rows of X and of Z.
 * col : number of rows of Y, which is the number of columns of Z.
 * c_r : number of columns shared by X and Y.
 */
void mulmt(double *_z, double const *_x, double const *_y, size_t row, size_t col, size_t c_r)
{
    size_t const count = row * col;
    size_t i, j, k;

    /* Start from an all-zero result; the products below are added onto it. */
    for (i = 0; i < count; ++i)
    {
        _z[i] = 0;
    }

    for (i = 0; i < row; ++i)
    {
        size_t const zoff = i * col; /* start of row i of Z */
        size_t const xoff = i * c_r; /* start of row i of X */
        for (j = 0; j < col; ++j)
        {
            size_t const yoff = j * c_r; /* start of row j of Y */
            /* Z[i][j] accumulates the dot product of X row i and Y row j. */
            for (k = 0; k < c_r; ++k)
            {
                _z[zoff + j] += _x[xoff + k] * _y[yoff + k];
            }
        }
    }
}

\(X^TY^T\)

\[ \begin{aligned} \pmb Z_{rc}&=\pmb X_{nr}^T\pmb Y_{cn}^T \\&= \begin{bmatrix} x_{11} & \cdots & x_{1r} \\ \vdots & \ddots & \vdots \\ x_{n1} & \cdots & x_{nr} \\ \end{bmatrix}^T \begin{bmatrix} y_{11} & \cdots & y_{1n} \\ \vdots & \ddots & \vdots \\ y_{c1} & \cdots & y_{cn} \\ \end{bmatrix}^T \\&= \begin{bmatrix} (x_{11}y_{11}+\ldots+x_{n1}y_{1n}) & \cdots & (x_{11}y_{c1}+\ldots+x_{n1}y_{cn}) \\ \vdots & \ddots & \vdots \\ (x_{1r}y_{11}+\ldots+x_{nr}y_{1n}) & \cdots & (x_{1r}y_{c1}+\ldots+x_{nr}y_{cn}) \\ \end{bmatrix} \end{aligned} \]

/*
 * Compute Z = X^T * Y^T.
 *
 * All three matrices are stored contiguously, one row after another.
 * The result is built as a sum of c_r outer products: for each k, row k of X
 * (a column of X^T) times column k of Y (a row of Y^T).
 *
 * _z  : output of row * col elements; every element is overwritten.
 * _x  : left operand before transposition, c_r rows of row elements each.
 * _y  : right operand before transposition, col rows of c_r elements each.
 * row : number of columns of X, which is the number of rows of Z.
 * c_r : number of rows of X, equal to the number of columns of Y.
 * col : number of rows of Y, which is the number of columns of Z.
 *
 * Y is read with a stride of c_r through explicit subscripts. Walking a
 * pointer through column k in steps of c_r would, on its final increment,
 * land more than one element past the end of Y for every k >= 1, which is
 * undefined behaviour even if the pointer is never dereferenced. Subscripts
 * keep every address that is formed inside the array.
 */
void multt(double *_z, double const *_x, double const *_y, size_t row, size_t c_r, size_t col)
{
    size_t const count = row * col;
    size_t i, j, k;

    /* Start from an all-zero result; the products below are added onto it. */
    for (i = 0; i < count; ++i)
    {
        _z[i] = 0;
    }

    for (k = 0; k < c_r; ++k)
    {
        size_t const xoff = k * row; /* start of row k of X */
        for (i = 0; i < row; ++i)
        {
            size_t const zoff = i * col; /* start of row i of Z */
            /* Add X[k][i] times column k of Y onto row i of Z. */
            for (j = 0; j < col; ++j)
            {
                _z[zoff + j] += _x[xoff + i] * _y[j * c_r + k];
            }
        }
    }
}

参考


手写矩阵乘法
https://blog.tqfx.org/posts/Handwrite-Matrix-Multiplication/
作者
tqfx
发布于
2024年1月31日
许可协议