-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathamx.matrix.h
148 lines (121 loc) · 3.91 KB
/
amx.matrix.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#pragma once
#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "amx.print.h"
#include "amx.tools.h"
#include "amx.types.h"
namespace amx {
// There are three types of matrices/tiles:
// - vanilla tile/matrix: memory has a row-major layout. If the source data is also row-major,
//   you can start reading from an offset and read with stride 64.
// - transposed tile: the tile is annotated as transposed. All operations on this tile behave as
//   if the tile were a vanilla tile whose content (1024 bytes) had been transposed in memory.
//   Marking a tile as transposed via this property avoids expensive memory shuffling.
// - amx_layout: the tile has the memory layout that AMX TMUL operations require for their
//   second source tile.
//
// E.g. to multiply two source matrices A[N][K] and B[K][M] (N<=16, M<=16, K<=64), load matrix A
// into vanilla tile T1. If matrix B is a vanilla matrix, it needs to be transposed and put into
// amx_layout; this tile T2 can then be used with instruction 'tdpbssd' to yield C[N][M].
//

/// Dense 2-D matrix of T stored row-major in a std::vector: element (col, row) lives at
/// index row * n_cols_ + col. All accessors take coordinates in (col, row) order.
///
/// NOTE(review): n_cols_ / n_rows_ are public and writable; assigning them directly does not
/// resize data_. Construct a new Matrix instead of mutating the dimensions.
template <typename T>
class Matrix {
private:
    std::vector<T> data_;

    /// Set the dimensions and (re)allocate n_cols * n_rows value-initialised elements.
    /// Assumes n_cols and n_rows are non-negative.
    void init(int n_cols, int n_rows) {
        this->n_cols_ = n_cols;
        this->n_rows_ = n_rows;
        this->data_ = std::vector<T>(this->n_cols_ * this->n_rows_);
    }

public:
    int n_cols_ = 0;
    int n_rows_ = 0;

    /// Empty 0x0 matrix.
    Matrix() : Matrix(0, 0) {}

    /// n_cols x n_rows matrix with every element set to `value`.
    explicit Matrix(int n_cols, int n_rows, T value)
    {
        init(n_cols, n_rows);
        fill(value);
    }

    /// n_cols x n_rows matrix with value-initialised elements.
    explicit Matrix(int n_cols, int n_rows)
    {
        init(n_cols, n_rows);
    }

    Matrix(const Matrix&) = default;
    Matrix& operator=(const Matrix&) = default;

    /// Move construction steals the element buffer; `other` is left as a consistent 0x0 matrix.
    Matrix(Matrix&& other) noexcept
        : data_(std::move(other.data_)),
          n_cols_(std::exchange(other.n_cols_, 0)),
          n_rows_(std::exchange(other.n_rows_, 0))
    {
    }

    /// Move assignment steals the element buffer; `other` is left as a consistent 0x0 matrix.
    Matrix& operator=(Matrix&& other) noexcept
    {
        if (this != &other) {
            this->data_ = std::move(other.data_);
            other.data_.clear();  // moved-from vector contents are unspecified; make them empty
            this->n_cols_ = std::exchange(other.n_cols_, 0);
            this->n_rows_ = std::exchange(other.n_rows_, 0);
        }
        return *this;
    }

    ~Matrix() = default;

    /// Two matrices are equal when their dimensions and all elements compare equal.
    inline bool operator==(const Matrix& other) const noexcept
    {
        return
            (this->n_cols_ == other.n_cols_) &&
            (this->n_rows_ == other.n_rows_) &&
            (this->data_ == other.data_);
    }

    [[nodiscard]] int get_n_cols() const noexcept {
        return this->n_cols_;
    }

    [[nodiscard]] int get_n_rows() const noexcept {
        return this->n_rows_;
    }

    /// Linear index of (col, row) in the row-major buffer.
    /// When DEBUG is enabled, coordinates outside [0, n_cols_) x [0, n_rows_) are reported on
    /// std::cerr and __debugbreak() is invoked. Without DEBUG no check happens here; an
    /// out-of-range column can then silently alias an element of a neighbouring row.
    [[nodiscard]] int pos(int col, int row) const noexcept {
        if constexpr (DEBUG) {
            if ((col < 0) || (col >= this->n_cols_)) {
                std::cerr << "Matrix.pos: invalid dimensions: column out of range [0, n_cols_) (column = " << col << ", n_cols_ = " << this->n_cols_ << ")" << std::endl;
                __debugbreak();
            }
            if ((row < 0) || (row >= this->n_rows_)) {
                std::cerr << "Matrix.pos: invalid dimensions: row out of range [0, n_rows_) (row = " << row << ", n_rows_ = " << this->n_rows_ << ")" << std::endl;
                __debugbreak();
            }
        }
        return (row * this->n_cols_) + col;
    }

    /// Element at (col, row); std::vector::at throws std::out_of_range if the linear index
    /// falls outside the buffer.
    [[nodiscard]] const T& get(int col, int row) const {
        return this->data_.at(this->pos(col, row));
    }

    /// Mutable element at (col, row); same bounds behaviour as the const overload.
    [[nodiscard]] T& get(int col, int row) {
        return this->data_.at(this->pos(col, row));
    }

    /// Element addressed by a MatrixKey, decoded with get_col/get_row (declared elsewhere,
    /// presumably amx.types.h).
    [[nodiscard]] const T& get(MatrixKey key) const {
        return this->get(get_col(key), get_row(key));
    }

    /// Mutable element addressed by a MatrixKey.
    [[nodiscard]] T& get(MatrixKey key) {
        return this->get(get_col(key), get_row(key));
    }

    /// Copy-assign `value` into (col, row).
    void set(int col, int row, const T& value) {
        this->data_.at(this->pos(col, row)) = value;
    }

    /// Move-assign `value` into (col, row).
    void set(int col, int row, T&& value) {
        this->data_.at(this->pos(col, row)) = std::move(value);
    }

    /// Assign `value` to every element. Walks the buffer linearly (row-major order) instead of
    /// going column-first through bounds-checked set() calls.
    void fill(T value) {
        std::fill(this->data_.begin(), this->data_.end(), value);
    }

    /// Render the matrix as text: a "(#columns = ..; #rows = ..)" header line, an ANSI colour
    /// reset when `colour` is true, then one line per row built from
    /// tools::pretty_print_value(element, column, colour, pt) for each column.
    [[nodiscard]] std::string pretty_print(bool colour, tools::PrintType pt) const {
        std::stringstream ss;
        ss << "(#columns = " << this->n_cols_ << "; #rows = " << this->n_rows_ << ")" << '\n';
        ss << ((colour) ? "\u001b[0m" : ""); // reset colour
        for (int row = 0; row < this->n_rows_; ++row) {
            for (int column = 0; column < this->n_cols_; ++column) {
                ss << amx::tools::pretty_print_value(this->get(column, row), column, colour, pt);
            }
            ss << '\n';
        }
        return ss.str();
    }
};
}