module dvec.matrix;

import std.traits : isNumeric, isFloatingPoint;
import dvec.vector;

/**
 * A fixed-size matrix with `rowCount` rows and `colCount` columns, stored in
 * row-major order in a single static array.
 * Params:
 *   T = The numeric element type.
 *   rowCount = The number of rows; must be positive.
 *   colCount = The number of columns; must be positive.
 */
struct Mat(T, size_t rowCount, size_t colCount) if (isNumeric!T && rowCount > 0 && colCount > 0) {
    /// Elements in row-major order: element (i, j) is stored at `data[colCount * i + j]`.
    public T[rowCount * colCount] data;

    /**
     * Constructs a matrix from a static array of elements in row-major order.
     * Params:
     *   elements = The elements, row by row.
     */
    public this(T[rowCount * colCount] elements) {
        data[0 .. $] = elements[0 .. $];
    }

    /**
     * Constructs a matrix from a list of elements in row-major order.
     * Params:
     *   elements = The elements, row by row. Exactly `rowCount * colCount`
     *              values must be supplied, since they are slice-copied into `data`.
     */
    public this(T[] elements...) {
        data[0 .. $] = elements[0 .. $];
    }

    /**
     * Constructs a matrix as an element-wise copy of another matrix of the same shape.
     * Params:
     *   other = The matrix to copy.
     */
    public this(Mat!(T, rowCount, colCount) other) {
        static foreach (i; 0 .. data.length) data[i] = other.data[i];
    }

    /**
     * Constructs a matrix with every element set to `value`.
     * Params:
     *   value = The value to fill the matrix with.
     */
    public this(T value) {
        static foreach (i; 0 .. data.length) data[i] = value;
    }

    /// Maps a (row, column) pair to its offset in the row-major `data` array.
    private size_t convertToIndex(size_t i, size_t j) {
        // Without this check a column index >= colCount would silently alias
        // into the next row instead of being reported.
        assert(i < rowCount && j < colCount, "Matrix index out of bounds.");
        return colCount * i + j;
    }

    /// Returns the element at row `i`, column `j`.
    public T opIndex(size_t i, size_t j) {
        return data[convertToIndex(i, j)];
    }

    /// Sets the element at row `i`, column `j` to `value`.
    public void opIndexAssign(T value, size_t i, size_t j) {
        data[convertToIndex(i, j)] = value;
    }

    /**
     * Builds a vector from the elements of one row.
     * Params:
     *   row = The zero-based row index.
     * Returns: A `Vec` constructed from that row's slice of `data`.
     */
    public Vec!(T, colCount) getRow(size_t row) {
        size_t idx = convertToIndex(row, 0);
        return Vec!(T, colCount)(data[idx .. idx + colCount]);
    }

    /**
     * Overwrites one row with the contents of a vector.
     * Params:
     *   row = The zero-based row index.
     *   vector = The new row values.
     */
    public void setRow(size_t row, Vec!(T, colCount) vector) {
        size_t idx = convertToIndex(row, 0);
        data[idx .. idx + colCount] = vector.data;
    }

    /**
     * Builds a vector from the elements of one column.
     * Params:
     *   col = The zero-based column index.
     * Returns: A vector holding the column's elements, top to bottom.
     */
    public Vec!(T, rowCount) getCol(size_t col) {
        Vec!(T, rowCount) v;
        static foreach (i; 0 .. rowCount) {
            v[i] = this[i, col];
        }
        return v;
    }

    /**
     * Overwrites one column with the contents of a vector.
     * Params:
     *   col = The zero-based column index.
     *   vector = The new column values, top to bottom.
     */
    public void setCol(size_t col, Vec!(T, rowCount) vector) {
        static foreach (i; 0 .. rowCount) {
            this[i, col] = vector[i];
        }
    }

    /// Adds `other` to this matrix element-wise, in place.
    public void add(Mat!(T, rowCount, colCount) other) {
        static foreach (i; 0 .. data.length) data[i] += other.data[i];
    }

    /// Subtracts `other` from this matrix element-wise, in place.
    public void sub(Mat!(T, rowCount, colCount) other) {
        static foreach (i; 0 .. data.length) data[i] -= other.data[i];
    }

    /// Multiplies every element by `factor`, in place.
    public void mul(T factor) {
        static foreach (i; 0 .. data.length) data[i] *= factor;
    }

    /// Divides every element by `factor`, in place.
    public void div(T factor) {
        static foreach (i; 0 .. data.length) data[i] /= factor;
    }

    /**
     * Computes the transpose of this matrix.
     * Returns: A new `colCount` x `rowCount` matrix `m` with `m[j, i] == this[i, j]`.
     */
    public Mat!(T, colCount, rowCount) transpose() {
        Mat!(T, colCount, rowCount) m;
        static foreach (i; 0 .. rowCount) {
            static foreach (j; 0 .. colCount) {
                m[j, i] = this[i, j];
            }
        }
        return m;
    }

    /**
     * Computes the matrix multiplication of `this * other`.
     * Only accepted at compile time when `other` has exactly `colCount` rows.
     * The result's element type is that of `other`.
     * Params:
     *   other = The matrix to multiply with this one.
     * Returns: The resultant `rowCount` x `otherColCount` matrix.
     */
    public Mat!(U, rowCount, otherColCount) mul(U, size_t otherRowCount, size_t otherColCount)
            (Mat!(U, otherRowCount, otherColCount) other) if (otherRowCount == colCount) {
        Mat!(U, rowCount, otherColCount) m;
        U sum;
        static foreach (i; 0 .. rowCount) {
            static foreach (j; 0 .. otherColCount) {
                sum = 0;
                static foreach (k; 0 .. colCount) {
                    sum += this[i, k] * other[k, j];
                }
                m[i, j] = sum;
            }
        }
        return m;
    }

    /**
     * Multiplies a vector against this matrix.
     * Params:
     *   vector = The vector to multiply, treated as a column vector.
     * Returns: The resultant transformed vector.
     */
    public Vec!(T, rowCount) mul(Vec!(T, colCount) vector) {
        Vec!(T, rowCount) result;
        T sum;
        static foreach (i; 0 .. rowCount) {
            sum = 0;
            static foreach (j; 0 .. colCount) {
                sum += this[i, j] * vector[j];
            }
            result[i] = sum;
        }
        return result;
    }

    /**
     * Elementary row operation: exchanges rows `rowI` and `rowJ` in place.
     * Params:
     *   rowI = The first row index.
     *   rowJ = The second row index.
     */
    public void rowSwitch(size_t rowI, size_t rowJ) {
        if (rowI == rowJ) return;
        immutable size_t a = convertToIndex(rowI, 0);
        immutable size_t b = convertToIndex(rowJ, 0);
        foreach (k; 0 .. colCount) {
            T tmp = data[a + k];
            data[a + k] = data[b + k];
            data[b + k] = tmp;
        }
    }

    /**
     * Elementary row operation: multiplies every element of a row by `factor`, in place.
     * Params:
     *   row = The row index.
     *   factor = The scaling factor.
     */
    public void rowMultiply(size_t row, T factor) {
        size_t idx = convertToIndex(row, 0);
        static foreach (i; 0 .. colCount) {
            data[idx + i] *= factor;
        }
    }

    /**
     * Elementary row operation: adds `factor` times row `rowJ` to row `rowI`,
     * in place (`rowI += factor * rowJ`). Row `rowJ` is left unchanged unless
     * it is the same row as `rowI`.
     * Params:
     *   rowI = The row being modified.
     *   factor = The multiple of `rowJ` to add.
     *   rowJ = The source row.
     */
    public void rowAdd(size_t rowI, T factor, size_t rowJ) {
        immutable size_t dst = convertToIndex(rowI, 0);
        immutable size_t src = convertToIndex(rowJ, 0);
        foreach (k; 0 .. colCount) {
            // Each element reads its own source value before it is updated,
            // so this is also correct when rowI == rowJ.
            data[dst + k] += factor * data[src + k];
        }
    }

    /**
     * Builds the matrix that remains after removing the given rows and columns.
     * The indices in `rows` and in `cols` must be distinct and in range;
     * otherwise the remaining elements do not fit the result's shape.
     * Params:
     *   rows = The row indices to remove.
     *   cols = The column indices to remove.
     * Returns: The `(rowCount - n)` x `(colCount - m)` submatrix, in the original order.
     */
    public Mat!(T, rowCount - n, colCount - m) subMatrix(size_t n, size_t m)(size_t[n] rows, size_t[m] cols)
            if (n < rowCount && m < colCount) {
        // The constraint is written as a comparison: the former `rowCount - n > 0`
        // wrapped around for n > rowCount because size_t is unsigned.
        import std.algorithm.searching : canFind;

        Mat!(T, rowCount - n, colCount - m) sub;
        size_t subIdx = 0;
        foreach (row; 0 .. rowCount) {
            if (canFind(rows[], row)) continue;
            foreach (col; 0 .. colCount) {
                if (canFind(cols[], col)) continue;
                sub.data[subIdx++] = this[row, col];
            }
        }
        return sub;
    }

    static if (rowCount == colCount) {
        alias N = rowCount;

        /// Returns the N x N identity matrix.
        public static Mat!(T, N, N) identity() {
            Mat!(T, N, N) m;
            static foreach (i; 0 .. N) {
                static foreach (j; 0 .. N) {
                    m[i, j] = i == j ? 1 : 0;
                }
            }
            return m;
        }

        /**
         * Computes the determinant.
         * Uses a direct formula for N <= 2 and cofactor (Laplace) expansion
         * along the first row otherwise. The cost grows factorially with N.
         */
        public T det() {
            static if (N == 1) {
                return data[0];
            } else static if (N == 2) {
                // Cast back to T: small integral types are promoted to int here.
                return cast(T) (data[0] * data[3] - data[1] * data[2]);
            } else {
                // Laplace expansion, taking i = 0.
                T sum = 0;
                static foreach (j; 0 .. N) {
                    sum += (j % 2 == 0 ? 1 : -1) * this[0, j] * this.subMatrix([0], [j]).det();
                }
                return sum;
            }
        }

        /// Returns true if the determinant is nonzero.
        public bool invertible() {
            return det() != 0;
        }

        /**
         * Computes the cofactor matrix: entry (i, j) is the signed determinant
         * of the minor obtained by removing row i and column j.
         */
        public Mat!(T, N, N) cofactor() {
            static if (N == 1) {
                // The only minor of a 1x1 matrix is the empty matrix, whose
                // determinant is 1; so the cofactor matrix is [1], not [a].
                return identity();
            } else {
                Mat!(T, N, N) c;
                static foreach (i; 0 .. N) {
                    static foreach (j; 0 .. N) {
                        c[i, j] = cast(T) (((i + j) % 2 == 0 ? 1 : -1) * this.subMatrix([i], [j]).det());
                    }
                }
                return c;
            }
        }

        /// Computes the adjugate: the transpose of the cofactor matrix.
        public Mat!(T, N, N) adjugate() {
            return cofactor().transpose();
        }

        /**
         * Computes the inverse as `adjugate() / det()`.
         * No check is made for a zero determinant. For floating-point T the
         * result then contains non-finite values; for integral T the division
         * by zero fails, and nonzero results are truncated by integer division.
         */
        public Mat!(T, N, N) inv() {
            auto m = adjugate();
            m.div(det());
            return m;
        }
    }
}

// Aliases for common matrix types.
/// Single-precision square matrices.
alias Mat2f = Mat!(float, 2, 2);
alias Mat3f = Mat!(float, 3, 3);
alias Mat4f = Mat!(float, 4, 4);

/// Double-precision square matrices.
alias Mat2d = Mat!(double, 2, 2);
alias Mat3d = Mat!(double, 3, 3);
alias Mat4d = Mat!(double, 4, 4);

unittest {
    import std.stdio;
    import dvec.vector;

    // Construction and transposition.
    {
        auto blank = Mat3d();
        assert(blank.data.length == 9);

        auto wide = Mat!(double, 2, 3)([1, 2, 3, 0, -6, 7]);
        auto tall = wide.transpose();
        assert(tall.data == [1, 0, 2, -6, 3, 7]);
    }

    // Row/column access and matrix-matrix multiplication.
    {
        auto square = Mat2d([1, 2, 3, 4]);
        assert(square.getRow(0).data == [1, 2]);
        assert(square.getRow(1).data == [3, 4]);
        assert(square.getCol(0).data == [1, 3]);
        assert(square.getCol(1).data == [2, 4]);

        auto product = square.mul(Mat2d([0, 1, 0, 0]));
        assert(product.data == [0, 1, 0, 3]);
    }

    // Identity matrices.
    {
        assert(Mat2d.identity().data == [1, 0, 0, 1]);
        assert(Mat3d.identity().data == [1, 0, 0, 0, 1, 0, 0, 0, 1]);
    }

    // Matrix-vector multiplication.
    {
        auto transform = Mat!(double, 2, 3)([1, -1, 2, 0, -3, 1]);
        auto point = Vec3d(2, 1, 0);
        assert(transform.mul(point).data == [1, -3]);
    }

    // Submatrix extraction.
    {
        auto grid = Mat!(double, 3, 4)([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
        auto trimmed = grid.subMatrix([2], [1]);
        assert(trimmed.data == [1, 3, 4, 5, 7, 8]);
    }

    // Determinants.
    {
        assert(Mat2d([3, 7, 1, -4]).det() == -19);
        assert(Mat2d([1, 2, 3, 4]).det() == -2);
        assert(Mat3d([1, 2, 3, 4, 5, 6, 7, 8, 9]).det() == 0);
    }

    // Inverses.
    {
        assert(Mat2d.identity().inv() == Mat2d.identity());
        assert(Mat3f.identity().inv() == Mat3f.identity());
        assert(Mat2d(4, 7, 2, 6).inv() == Mat2d(0.6, -0.7, -0.2, 0.4));
        assert(Mat2d(-3, 1, 5, -2).inv() == Mat2d(-2, -1, -5, -3));
        assert(Mat3d(1, 3, 3, 1, 4, 3, 1, 3, 4).inv() == Mat3d(7, -3, -3, -1, 1, 0, -1, 0, 1));
    }
}