Recently, Ben Newhouse released a TypeScript-based implementation of GPT called potatogpt. Although it may be slow, it contains a very interesting approach to type-checking tensor arithmetic. This approach eliminates the need to run your code to verify whether operations are allowed, or to keep track of the sizes of tensors in your head.
The implementation is quite complex, employing several advanced TypeScript techniques. In order to make it more accessible and easier to understand, I’ve attempted to simplify and explain the implementation with clarifying comments below.
Finally, I show how this approach allows us to easily create type-safe versions of functions like zip and matmul.
Exact dimensions
For Tensors to have exact dimensions, we need to support only numeric literals (e.g. 16, 768, etc.) for sizes known at compile time, and “branded types” for sizes only known at runtime. We must disallow non-literal number types and unions of numbers (e.g. 16 | 768) because, if these were introduced into an application, any data produced from them would also lack exact dimensions.
typescript
// We check whether `T` is a numeric literal by checking that `number`
// does not extend from `T` but that `T` does extend from `number`.
type IsNumericLiteral<T> = number extends T
  ? false
  : T extends number
  ? true
  : false;

// In order to support runtime-determined sizes we use a "branded type"
// to give these dimensions labels that they can be type-checked with
// and a function `Var` to generate values with this type.
export type Var<Label extends string> = number & { label: Label };
export const Var = <Label extends string>(size: number, label: Label) => {
  return size as Var<Label>;
};
type IsVar<T> = T extends Var<string> ? true : false;

// For type-checking of tensors to work they must only ever be
// created using numeric literals (e.g. `5`) or `Var<string>`
// and never from types like `number` or `1 | 2 | 3`.
type IsNumericLiteralOrVar<T extends number | Var<string>> = And<
  // We disallow `T` to be a union of types.
  Not<IsUnion<T>>,
  Or<
    // We allow `T` to be a numeric literal but not a number.
    IsNumericLiteral<T>,
    // We allow `T` to be a `Var`.
    IsVar<T>
  >
>;

// Utilities
type And<A, B> = A extends true ? (B extends true ? true : false) : false;
type Or<A, B> = A extends true ? true : B extends true ? true : false;
type Not<A> = A extends true ? false : true;

// `IsUnion` is based on the principle that a union like `A | B` does not
// extend an intersection like `A & B`. The conditional type uses a
// "tuple trick" technique that avoids distributing the type `T` over
// `UnionToIntersection` by wrapping the type into a one-element tuple.
// This means that if `T` is `'A' | 'B'` the expression is evaluated
// as `['A' | 'B'] extends [UnionToIntersection<'A' | 'B'>]` instead of
// `'A' | 'B' extends UnionToIntersection<'A'> | UnionToIntersection<'B'>`.
type IsUnion<T> = [T] extends [UnionToIntersection<T>] ? false : true;

// `UnionToIntersection` takes a union type and uses a "distributive
// conditional type" to map over each element of the union and create a
// series of function types with each element as their argument. It then
// infers the first argument of each of these functions to create a new
// type that is the intersection of all the types in the original union.
type UnionToIntersection<Union> = (
  Union extends unknown ? (distributedUnion: Union) => void : never
) extends (mergedIntersection: infer Intersection) => void
  ? Intersection
  : never;
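To see how these predicates behave on concrete inputs, the sketch below re-declares the helper types so that it stands alone and evaluates IsNumericLiteralOrVar on four cases. The constant names are hypothetical and exist only for illustration; each assignment type-checks only if the type resolves to the value on the right-hand side.

```typescript
// Helper types re-declared from the article so this sketch is self-contained.
type Var<Label extends string> = number & { label: Label };
type IsNumericLiteral<T> = number extends T
  ? false
  : T extends number
  ? true
  : false;
type IsVar<T> = T extends Var<string> ? true : false;
type And<A, B> = A extends true ? (B extends true ? true : false) : false;
type Or<A, B> = A extends true ? true : B extends true ? true : false;
type Not<A> = A extends true ? false : true;
type UnionToIntersection<Union> = (
  Union extends unknown ? (distributedUnion: Union) => void : never
) extends (mergedIntersection: infer Intersection) => void
  ? Intersection
  : never;
type IsUnion<T> = [T] extends [UnionToIntersection<T>] ? false : true;
type IsNumericLiteralOrVar<T extends number | Var<string>> = And<
  Not<IsUnion<T>>,
  Or<IsNumericLiteral<T>, IsVar<T>>
>;

// A single numeric literal is accepted.
const literalIsAccepted: IsNumericLiteralOrVar<16> = true;
// A branded runtime size is accepted ("seqLen" is a made-up label).
const brandedIsAccepted: IsNumericLiteralOrVar<Var<"seqLen">> = true;
// The widened `number` type is rejected.
const plainNumberIsRejected: IsNumericLiteralOrVar<number> = false;
// A union of literals is rejected by the `Not<IsUnion<T>>` clause.
const unionIsRejected: IsNumericLiteralOrVar<16 | 768> = false;
```

Swapping any of the right-hand values (for example assigning true to unionIsRejected) should produce a compile error, which is a quick way to confirm what a predicate resolves to.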
If you need to, you can read further on the more advanced TypeScript techniques here:
Tensor
We can then implement a type-safe Tensor with a unique constraint: the dimensions must be specified using numeric literals or “branded types”. This approach pushes the limits of TypeScript’s standard type-checking capabilities and requires a non-idiomatic usage of conditional types to represent these errors. Note that we diverged from Ben’s original implementation by enforcing this dimensional constraint at the argument level, instead of at the return level with a conditional return type that produces an invalid tensor. The downside of this is that you must use as const on the shape argument to prevent TypeScript from widening the literal types to number.
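The widening behaviour mentioned above can be seen in isolation. The following is a minimal sketch that is independent of the Tensor code; the names widenedShape, exactShape and IsWidened are hypothetical and chosen only for this illustration.

```typescript
// Without `as const`, the array literal is widened to `number[]`.
const widenedShape = [10, 100];

// With `as const`, TypeScript keeps the tuple type `readonly [10, 100]`.
const exactShape = [10, 100] as const;

// `number extends T` only holds when `T` has been widened to `number`.
type IsWidened<T> = number extends T ? true : false;

// The element type of the widened array is `number`.
const widenedElementIsNumber: IsWidened<(typeof widenedShape)[number]> = true;

// The first element of the `as const` tuple keeps its literal type `10`.
const exactElementIsNumber: IsWidened<(typeof exactShape)[0]> = false;

// The tuple's length is also a literal type, so the rank is known statically.
const rank: (typeof exactShape)["length"] = 2;
```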
typescript
export type Dimension = number | Var<string>;
export type Tensor<Shape extends readonly Dimension[]> = {
  data: Float32Array;
  shape: Shape;
};
export function tensor<const Shape extends readonly Dimension[]>(
  shape: AssertShapeEveryElementIsNumericLiteralOrVar<Shape>,
  init?: number[]
): Tensor<Shape> {
  return {
    data: init
      ? new Float32Array(init)
      : new Float32Array((shape as Shape).reduce((a, b) => a * b, 1)),
    shape: shape as Shape,
  };
}

// `ArrayEveryElementIsNumericLiteralOrVar` is similar to JavaScript's
// `Array#every` in that it checks that a particular condition is true of
// every element in an array and returns `true` if this is the case. In
// TypeScript we have to hardcode our condition (`IsNumericLiteralOrVar`)
// as we do not yet have higher-kinded generic types that can take in
// other generic types and apply these.
//
// In the code below we create a "mapped object type" from an array type
// and then apply the condition to each value in the mapped object type.
// We then use a conditional type to check whether the type outputted
// extends from a type in which the value at every key is `true`.
type ArrayEveryElementIsNumericLiteralOrVar<
  T extends ReadonlyArray<number | Var<string>>
> = T extends ReadonlyArray<unknown>
  ? { [K in keyof T]: IsNumericLiteralOrVar<T[K]> } extends {
      [K in keyof T]: true;
    }
    ? true
    : false
  : false;

type InvalidArgument<T> = readonly [never, T];
type AssertShapeEveryElementIsNumericLiteralOrVar<
  T extends ReadonlyArray<number | Var<string>>
> = true extends ArrayEveryElementIsNumericLiteralOrVar<T>
  ? T
  : ReadonlyArray<
      InvalidArgument<"The `shape` argument must be marked `as const` and only contain number literals or branded types.">
    >;

// Tests
const fourDimensionalTensorWithStaticSizes = tensor([
  10, 100, 1000, 10000,
] as const);
const threeDimensionalTensorWithRuntimeSize = tensor([
  5,
  Var(3, "dim"),
  10,
] as const);
const invalidTensor1 = tensor([10, 100, 1000, 10000]);
const invalidTensor2 = tensor([10 as number, 100, 1000, 10000] as const);
// Error 2322: Type 'number' is not assignable to type 'InvalidArgument<"The `shape` argument must be marked `as const` and only contain number literals or branded types.">'.
const invalidTensor3 = tensor([5, 3 as 3 | 6 | 9, 10] as const);
// Error 2322: Type 'number' is not assignable to type 'InvalidArgument<"The `shape` argument must be marked `as const` and only contain number literals or branded types.">'.
//   Type 'number' is not assignable to type 'readonly [never, "The `shape` argument must be marked `as const` and only contain number literals or branded types."]'.
If you need to, you can read further on the more advanced TypeScript techniques here:
Matrix
typescript
function isDimensionArray(
  maybeDimensionArray: any
): maybeDimensionArray is readonly Dimension[] {
  return (
    Array.isArray(maybeDimensionArray) &&
    maybeDimensionArray.some((d) => typeof d === "number")
  );
}

function is2DArray(maybe2DArray: any): maybe2DArray is number[][] {
  return (
    Array.isArray(maybe2DArray) &&
    maybe2DArray.some((row) => Array.isArray(row))
  );
}

function flat<T>(arr: T[][]): T[] {
  let result: T[] = [];
  for (let i = 0; i < arr.length; i++) {
    result.push.apply(result, arr[i]);
  }
  return result;
}

export type Matrix<Rows extends Dimension, Columns extends Dimension> = Tensor<
  readonly [Rows, Columns]
>;

export function matrix<
  const TwoDArray extends ReadonlyArray<ReadonlyArray<number>>
>(init: TwoDArray): Matrix<TwoDArray["length"], TwoDArray[0]["length"]>;
export function matrix<const Shape extends readonly [Dimension, Dimension]>(
  shape: AssertShapeEveryElementIsNumericLiteralOrVar<Shape>,
  init?: number[]
): Matrix<Shape[0], Shape[1]>;
export function matrix<const Shape extends readonly [Dimension, Dimension]>(
  shape: AssertShapeEveryElementIsNumericLiteralOrVar<Shape>,
  init?: number[]
): Matrix<Shape[0], Shape[1]> {
  let resolvedShape: readonly [any, any];
  if (isDimensionArray(shape)) {
    resolvedShape = shape;
  } else if (is2DArray(shape)) {
    resolvedShape = [shape.length, shape[0].length];
    init = flat(shape);
  } else {
    throw new Error("Invalid shape type for matrix.");
  }
  return tensor(resolvedShape, init);
}

// Tests
const matrixWithStaticSizes = matrix([25, 50] as const);
const matrixWithRuntimeSize = matrix([
  10,
  Var(100, "configuredDimensionName"),
] as const);
const matrixWithSizeFromData = matrix([
  [1, 2, 3],
  [4, 5, 6],
  [7, 8, 9],
]);
const invalidMatrix1 = matrix([25, 50]);
const invalidMatrix2 = matrix([25 as number, 50] as const);
// Error 2769: No overload matches this call.
//   Overload 2 of 2, '(shape: readonly InvalidArgument<"The `shape` argument must be marked `as const` and only contain number literals or branded types.">[], init?: number[] | undefined): Matrix<...>', gave the following error.
//     Type 'number' is not assignable to type 'InvalidArgument<"The `shape` argument must be marked `as const` and only contain number literals or branded types.">'.
const invalidMatrix3 = matrix([10, 100 as 100 | 115] as const);
// Error 2769: No overload matches this call.
//   Overload 2 of 2, '(shape: readonly InvalidArgument<"The `shape` argument must be marked `as const` and only contain number literals or branded types.">[], init?: number[] | undefined): Matrix<...>', gave the following error.
//     Type 'number' is not assignable to type 'InvalidArgument<"The `shape` argument must be marked `as const` and only contain number literals or branded types.">'.
//       Type 'number' is not assignable to type 'readonly [never, "The `shape` argument must be marked `as const` and only contain number literals or branded types."]'.
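The first matrix overload derives the shape from the data itself through TwoDArray["length"] and TwoDArray[0]["length"]. As a rough sketch of how that indexed-access lookup resolves on a readonly tuple type, consider the following; the names rowsOfData, Rows, Columns and runtimeShape are hypothetical, and as const stands in here for the const type parameter used above.

```typescript
// A 2x3 nested tuple; `as const` preserves the tuple lengths in the type.
const rowsOfData = [
  [1, 2, 3],
  [4, 5, 6],
] as const;

type TwoDArray = typeof rowsOfData;

// The outer tuple's length gives the number of rows: the literal type 2.
type Rows = TwoDArray["length"];

// The first row's length gives the number of columns: the literal type 3.
type Columns = TwoDArray[0]["length"];

// These assignments only type-check because both lengths are literal types.
const rowCount: Rows = 2;
const columnCount: Columns = 3;

// At runtime the same sizes are read from the array lengths.
const runtimeShape = [rowsOfData.length, rowsOfData[0].length];
```

Note that only the first row is consulted, so this lookup on its own would not detect rows of differing lengths.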
Vector
typescript
type AssertSizeIsNumericLiteralOrVar<T extends Dimension> =
  true extends IsNumericLiteralOrVar<T>
    ? T
    : InvalidArgument<"The `size` argument must only contain number literals or branded types.">;

export type RowVector<Size extends Dimension> = Tensor<readonly [1, Size]>;
export type Vector<Size extends Dimension> = RowVector<Size>;

export function vector<const OneDArray extends readonly Dimension[]>(
  init: OneDArray
): Vector<OneDArray["length"]>;
export function vector<const Size extends Dimension>(
  size: AssertSizeIsNumericLiteralOrVar<Size>,
  init?: number[]
): Vector<Size>;
export function vector<const Size extends Dimension>(
  size: AssertSizeIsNumericLiteralOrVar<Size>,
  init?: number[]
): Vector<Size> {
  let shape: readonly [1, any];
  if (typeof size === "number") {
    shape = [1, size];
  } else if (Array.isArray(size)) {
    shape = [1, size.length];
    init = size;
  } else {
    throw new Error("Invalid size type for vector.");
  }
  return tensor(shape, init);
}

// Tests
const vectorWithStaticSize = vector(2);
const vectorWithRuntimeSize = vector(Var(4, "configuredDimensionName"));
const vectorWithSizeFromData = vector([1, 2, 3]);
const invalidVector1 = vector(2 as number);
// Error 2769: No overload matches this call.
//   Overload 1 of 2, '(init: readonly Dimension[]): Vector<number>', gave the following error.
//     Argument of type 'number' is not assignable to parameter of type 'readonly Dimension[]'.
//   Overload 2 of 2, '(size: InvalidArgument<"The `size` argument must only contain number literals or branded types.">, init?: number[] | undefined): Vector<number>', gave the following error.
//     Argument of type 'number' is not assignable to parameter of type 'InvalidArgument<"The `size` argument must only contain number literals or branded types.">'.
const invalidVector2 = vector(100 as 100 | 115);
// Error 2769: No overload matches this call.
//   Overload 1 of 2, '(init: readonly Dimension[]): Vector<number>', gave the following error.
//     Argument of type 'number' is not assignable to parameter of type 'readonly Dimension[]'.
//       Type 'number' is not assignable to type 'readonly Dimension[]'.
//   Overload 2 of 2, '(size: InvalidArgument<"The `size` argument must only contain number literals or branded types.">, init?: number[] | undefined): Vector<100 | 115>', gave the following error.
//     Argument of type 'number' is not assignable to parameter of type 'InvalidArgument<"The `size` argument must only contain number literals or branded types.">'.
//       Type 'number' is not assignable to type 'readonly [never, "The `size` argument must only contain number literals or branded types."]'.
zip
Once we have Vector and Matrix types defined, we can use these to write a type-safe zip function that combines two Vectors of the same length into a Matrix of [VectorLength, 2], like so:
typescript
/**
 * The `zip` function combines two vectors of the same length into a matrix
 * where each row contains a pair of corresponding elements from the input
 * vectors. The output matrix's data is stored in a `Float32Array` with an
 * interleaved arrangement of elements (row-major storage order) for efficient
 * access.
 *
 * Example:
 * Input vectors: [a1, a2, a3] and [b1, b2, b3]
 * Output matrix:
 * | a1 b1 |
 * | a2 b2 |
 * | a3 b3 |
 *
 * Memory layout in Float32Array: [a1, b1, a2, b2, a3, b3]
 */
function zip<SameVector extends Vector<Dimension>>(
  a: SameVector,
  b: SameVector
): Matrix<SameVector["shape"][1], 2> {
  if (a.shape[1] !== b.shape[1]) {
    throw new Error(
      `zip cannot operate on different length vectors; ${a.shape[1]} !== ${b.shape[1]}`
    );
  }
  const length = a.shape[1];
  const resultData: number[] = [];
  for (let i = 0; i < length; i++) {
    resultData.push(a.data[i], b.data[i]);
  }
  return matrix([length as any, 2] as const, resultData);
}

// Tests
const threeElementVector1 = vector([1, 2, 3]);
const threeElementVector2 = vector([4, 5, 6]);
const fourElementVector1 = vector([7, 8, 9, 10]);
const zipped = zip(threeElementVector1, threeElementVector2);
const zippedError = zip(threeElementVector1, fourElementVector1);
// Error 2345: Argument of type 'Vector<4>' is not assignable to parameter of type 'Vector<3>'.
//   Type '4' is not assignable to type '3'.

const threeElementVector3 = vector(Var(3, "three"), [1, 2, 3]);
const threeElementVector4 = vector(Var(3, "three"), [5, 10, 15]);
const fourElementVector2 = vector(Var(4, "four"), [10, 11, 12, 13]);
const zipped2 = zip(threeElementVector3, threeElementVector4);
const zippedError2 = zip(threeElementVector3, fourElementVector2);
// Error 2345: Argument of type 'Vector<Var<"four">>' is not assignable to parameter of type 'Vector<Var<"three">>'.
//   Type 'Var<"four">' is not assignable to type 'Var<"three">'.
//     Type 'Var<"four">' is not assignable to type '{ label: "three"; }'.
//       Types of property 'label' are incompatible.
//         Type '"four"' is not assignable to type '"three"'.
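The interleaved, row-major layout described in the zip comment can be illustrated without any of the tensor types. This is a simplified sketch over plain number arrays; interleave and at are hypothetical helper names and are not part of the implementation above.

```typescript
// Interleave two equal-length arrays into row-major [length, 2] storage.
function interleave(a: number[], b: number[]): Float32Array {
  if (a.length !== b.length) {
    throw new Error(
      `cannot interleave different lengths; ${a.length} !== ${b.length}`
    );
  }
  const out = new Float32Array(a.length * 2);
  for (let i = 0; i < a.length; i++) {
    out[i * 2] = a[i]; // column 0 of row i
    out[i * 2 + 1] = b[i]; // column 1 of row i
  }
  return out;
}

// Read element (row, column) from row-major storage with `columns` columns.
function at(
  data: Float32Array,
  columns: number,
  row: number,
  column: number
): number {
  return data[row * columns + column];
}

// Rows are (1, 4), (2, 5), (3, 6), stored as [1, 4, 2, 5, 3, 6].
const zippedData = interleave([1, 2, 3], [4, 5, 6]);
const secondRowSecondColumn = at(zippedData, 2, 1, 1); // 5
```

The same row * columns + column arithmetic is what lets a flat Float32Array stand in for a two-dimensional matrix.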
matmul
Finally, functions like matmul that expect two operands with different but compatible shapes can be implemented using the same techniques:
typescript
function matmul<
  RowsA extends Dimension,
  SharedDimension extends Dimension,
  ColumnsB extends Dimension
>(
  a: Matrix<RowsA, SharedDimension>,
  b: IsNumericLiteralOrVar<SharedDimension> extends true
    ? Matrix<SharedDimension, ColumnsB>
    : InvalidArgument<"The rows dimension of the `b` matrix must match the columns dimension of the `a` matrix.">
): Matrix<RowsA, ColumnsB> {
  const aMatrix = a;
  const bMatrix = b as Matrix<SharedDimension, ColumnsB>;
  const [aRows, aCols] = aMatrix.shape;
  const [bRows, bCols] = bMatrix.shape;
  if (aCols !== bRows) {
    throw new Error(
      "The rows dimension of the `b` matrix must match the columns dimension of the `a` matrix."
    );
  }
  const shape = [aRows, bCols] as AssertShapeEveryElementIsNumericLiteralOrVar<
    [RowsA, ColumnsB]
  >;
  const data = Array<number>(aRows * bCols).fill(0);
  for (let rowIndex = 0; rowIndex < aRows; rowIndex++) {
    for (let columnIndex = 0; columnIndex < bCols; columnIndex++) {
      let dotProduct = 0;
      for (
        let sharedDimensionIndex = 0;
        sharedDimensionIndex < aCols;
        sharedDimensionIndex++
      ) {
        const rowCellFromA =
          aMatrix.data[rowIndex * aCols + sharedDimensionIndex];
        const columnCellFromB =
          bMatrix.data[sharedDimensionIndex * bCols + columnIndex];
        dotProduct += rowCellFromA * columnCellFromB;
      }
      data[rowIndex * bCols + columnIndex] = dotProduct;
    }
  }
  return matrix(shape, data);
}

// Tests
const a = matrix([2, 3] as const);
const b = matrix([3, 2] as const);
const c = matrix([7, 7] as const);
const validMatmul = matmul(a, b);
const invalidMatmul = matmul(a, c);
// Error 2345: Argument of type 'Matrix<7, 7>' is not assignable to parameter of type 'InvalidArgument<"The rows dimension of the `b` matrix must match the columns dimension of the `a` matrix.">'.
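As a sanity check of the indexing arithmetic in matmul, here is a small worked example on plain row-major arrays. multiplyRowMajor is a hypothetical stand-in that keeps the same loop structure but drops the type-level shape checks, so the shapes are passed as ordinary numbers.

```typescript
// Multiply a row-major (aRows x aCols) matrix by a row-major (bRows x bCols) one.
function multiplyRowMajor(
  a: number[],
  aRows: number,
  aCols: number,
  b: number[],
  bRows: number,
  bCols: number
): number[] {
  // Without the type-level check, the shape mismatch can only be caught at runtime.
  if (aCols !== bRows) {
    throw new Error(
      "The rows dimension of the `b` matrix must match the columns dimension of the `a` matrix."
    );
  }
  const data = Array<number>(aRows * bCols).fill(0);
  for (let rowIndex = 0; rowIndex < aRows; rowIndex++) {
    for (let columnIndex = 0; columnIndex < bCols; columnIndex++) {
      let dotProduct = 0;
      for (let sharedIndex = 0; sharedIndex < aCols; sharedIndex++) {
        dotProduct +=
          a[rowIndex * aCols + sharedIndex] *
          b[sharedIndex * bCols + columnIndex];
      }
      data[rowIndex * bCols + columnIndex] = dotProduct;
    }
  }
  return data;
}

// A = | 1 2 3 |   B = |  7  8 |
//     | 4 5 6 |       |  9 10 |
//                     | 11 12 |
const product = multiplyRowMajor(
  [1, 2, 3, 4, 5, 6], 2, 3,
  [7, 8, 9, 10, 11, 12], 3, 2
);
// product is [58, 64, 139, 154], i.e. the 2x2 matrix | 58 64 | / | 139 154 |
```

For instance, the top-left entry is 1*7 + 2*9 + 3*11 = 58. With the typed matmul, the mismatch that this sketch can only catch at runtime is rejected by the compiler instead.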