概要
numpy version=1.20からnumpy.typing
が提供されています。
型アノテーションを記述する際に、numpy.ndarray
で指定するのと、numpy.typing.NDArray
で指定するのは何が違うのかを整理します。
numpy.typing.NDArrayとは何か
Typing (numpy.typing) — NumPy v1.26 Manualによると、
A generic version of np.ndarray[Any, np.dtype[+ScalarType]].
と記載されています。
numpy==1.26.2におけるnumpy.typing.NDArray
の実装を確認してみます。
import numpy as np import numpy.typing as npt npt.NDArray
numpy.ndarray[typing.Any, numpy.dtype[+_ScalarType_co]]
コード上では以下のように実装されています。
from numpy import ( ndarray, dtype, generic, ... ) ... _ScalarType_co = TypeVar("_ScalarType_co", bound=generic, covariant=True) ... NDArray = ndarray[Any, dtype[_ScalarType_co]]
numpy.typing.NDArray
はndarray[Any, dtype[_ScalarType_co]]
の型エイリアスということが分かります。
また、_ScalarType_co
はTypeVar("_ScalarType_co", bound=generic, covariant=True)
という型変数です。
bound=generic
と指定しています。Scalars — NumPy v1.26 Manualによると、numpy.generic
はnumpyにおけるスカラー型の基底クラスとのことです。_ScalarType_co
はnumpy.generic
の部分型(要はスカラー型)を受け取ります。covariant=True
と指定しています。A
がB
の部分型の場合、dtype[A]
はdtype[B]
の部分型と見なされます。
ndarray[Any, dtype[_ScalarType_co]]
という実装について、Any
とdtype[_ScalarType_co]]
はそれぞれ何に対応しているのかを確認していきます。
コード上では、ndarray
は以下のように定義されています。
_DType_co = TypeVar("_DType_co", covariant=True, bound=dtype[Any]) ... # TODO: Set the `bound` to something more suitable once we # have proper shape support _ShapeType = TypeVar("_ShapeType", bound=Any) ... class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]): __hash__: ClassVar[None] ...
_ShapeType
はndarray
のshape情報が与えられます。
ただ、現状の_ShapeType
の実装ではbound=Any
となっているので、Any
の部分型(要はなんでも)を受け取れてしまいます。
TODOコメントにもあるように、shapeチェックができるようになると良いですね。
…というわけで、型アノテーションにおいて2つを以下のように整理しました。
ndarray
ではShapeTypeとDtypeの2つを指定できる。ただし、現状ShapeTypeには任意の型が渡せてしまう。NDArray
はndarray[Any, dtype[_ScalarType_co]]
と等価。
個人的には、型アノテーションがシンプルになることからNDArray
を使うのが良いんじゃないかなと思います1。
NDArrayを利用した型アノテーション
実際にNDArray
を使って型アノテーションしてみましょう。
例えば以下のように型アノテーションが可能です。
import numpy as np import numpy.typing as npt def add_one(xs: npt.NDArray) -> npt.NDArray: return xs + 1 xs = np.array([1, 2, 3]) add_one(xs)
$ mypy exp.py Success: no issues found in 1 source file
ただし上記ではNDArray
のdtype部分を指定していないため、mypy --strict
もしくはmypy --disallow-any-generics
では以下のように怒られます。
$ mypy --disallow-any-generics exp.py exp.py:5: error: Missing type parameters for generic type "npt.NDArray" [type-arg] Found 1 error in 1 file (checked 1 source file)
以下のようにdtypeを指定してあげればOKです。
def add_one(xs: npt.NDArray[np.int32]) -> npt.NDArray[np.int32]: return xs + 1
誤ってNDArray[np.float32]
型の値を渡すと、ちゃんとmypyが怒ってくれます。
xs: npt.NDArray[np.float32] = np.array([1, 2, 3]) add_one(xs)
$ mypy --disallow-any-generics exp.py exp.py:10: error: Argument 1 to "add_one" has incompatible type "ndarray[Any, dtype[floating[_32Bit]]]"; expected "ndarray[Any, dtype[signedinteger[_32Bit]]]" [arg-type] Found 1 error in 1 file (checked 1 source file)
まとめ
numpy.typing.NDArray
の実装および使い方について見てきました。
numpy.typingは他にもArrayLike
やDTypeLike
などの型を提供しており、うまく使ってあげることで素敵な型アノテーションができそうです。