You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ya...@apache.org on 2023/07/26 19:58:48 UTC
[tvm] branch unity updated: [Unity][Module] Add Core Data Structure (#15398)
This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 91733636d0 [Unity][Module] Add Core Data Structure (#15398)
91733636d0 is described below
commit 91733636d05b7a8aec50616355cdf6554e08f83d
Author: Junru Shao <ju...@apache.org>
AuthorDate: Wed Jul 26 12:58:42 2023 -0700
[Unity][Module] Add Core Data Structure (#15398)
This PR introduces the core data structure for the new `nn.Module` API,
including:
- Tensor, a wrapper on top of relax.Expr whose struct_info is a TensorStructInfo,
providing more convenient access to shape and dtype information.
Tensor is always symbolic and not bound to any concrete values.
- Parameter, a special tensor which could be bound or not bound to concrete values.
- Module, a container of nn.Parameters and sub nn.Modules.
- Effect, a non-user-facing class that encloses potential side effects, for example, IO,
impure external function callings, inplace mutation, etc.
---
python/tvm/relax/frontend/__init__.py | 5 +-
python/tvm/relax/frontend/{ => nn}/__init__.py | 6 +-
.../frontend/{__init__.py => nn/_tensor_op.py} | 9 +-
python/tvm/relax/frontend/nn/core.py | 419 +++++++++++++++++++++
4 files changed, 428 insertions(+), 11 deletions(-)
diff --git a/python/tvm/relax/frontend/__init__.py b/python/tvm/relax/frontend/__init__.py
index 4baf3195f0..c314245030 100644
--- a/python/tvm/relax/frontend/__init__.py
+++ b/python/tvm/relax/frontend/__init__.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-Frontends for constructing Relax programs, with the model importers
-"""
+"""Frontends for constructing Relax programs, with the model importers"""
+from . import nn
from .common import detach_params
diff --git a/python/tvm/relax/frontend/__init__.py b/python/tvm/relax/frontend/nn/__init__.py
similarity index 87%
copy from python/tvm/relax/frontend/__init__.py
copy to python/tvm/relax/frontend/nn/__init__.py
index 4baf3195f0..b7687eb924 100644
--- a/python/tvm/relax/frontend/__init__.py
+++ b/python/tvm/relax/frontend/nn/__init__.py
@@ -14,7 +14,5 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-Frontends for constructing Relax programs, with the model importers
-"""
-from .common import detach_params
+# pylint: disable=invalid-name
+"""A PyTorch-like API to build IRModules."""
diff --git a/python/tvm/relax/frontend/__init__.py b/python/tvm/relax/frontend/nn/_tensor_op.py
similarity index 87%
copy from python/tvm/relax/frontend/__init__.py
copy to python/tvm/relax/frontend/nn/_tensor_op.py
index 4baf3195f0..a5e4f9b0cb 100644
--- a/python/tvm/relax/frontend/__init__.py
+++ b/python/tvm/relax/frontend/nn/_tensor_op.py
@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-Frontends for constructing Relax programs, with the model importers
-"""
-from .common import detach_params
+"""Adding member operators to nn.Tensor."""
+
+
+class _TensorOp: # pylint: disable=too-few-public-methods
+ pass
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
new file mode 100644
index 0000000000..8f258dd81c
--- /dev/null
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -0,0 +1,419 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The core infra for nn.Module, which includes the following pieces:
+- Tensor, a wrapper on top of relax.Expr whose struct_info is a TensorStructInfo,
+ providing more convenient access to shape and dtype information.
+ Tensor is always symbolic and not bound to any concrete values.
+- Parameter, a special tensor which could be bound or not bound to concrete values.
+- Module, a container of nn.Parameters and sub nn.Modules.
+- Effect, a non-user-facing class that encloses potential side effects, for example, IO,
+ impure external function callings, inplace mutation, etc.
+"""
+from collections import OrderedDict
+from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+from tvm import tir
+from tvm.runtime import Device, NDArray, ndarray
+
+from ... import expr as rx
+from ...block_builder import BlockBuilder
+from ...struct_info import ShapeStructInfo, TensorStructInfo
+from ._tensor_op import _TensorOp
+
# Module-level default dtype for parameters created without an explicit dtype.
_DEFAULT_DTYPE = "float32"


def get_default_dtype() -> str:
    """Return the dtype used for parameters when none is specified.

    Returns
    -------
    dtype : str
        The current default dtype (initially "float32")
    """
    return _DEFAULT_DTYPE


def set_default_dtype(dtype: str) -> None:
    """Override the dtype used for parameters when none is specified.

    Parameters
    ----------
    dtype : str
        The dtype to install as the new default
    """
    global _DEFAULT_DTYPE  # pylint: disable=global-statement
    _DEFAULT_DTYPE = dtype
+
+
class Tensor(_TensorOp):
    """A wrapper on top of relax.Expr whose struct_info is a TensorStructInfo, providing
    convenient access to shape and dtype information. A Tensor is always symbolic and never
    bound to concrete values. Shape and dtype inference is done eagerly upon tensor creation,
    i.e. when operators are applied on tensors, the shape and dtype information is already
    available.
    """

    # The wrapped relax expression; its struct_info must be a fully-specified TensorStructInfo.
    _expr: rx.Expr

    def __init__(self, *, _expr: rx.Expr) -> None:
        """Private constructor. Users are never supposed to construct Tensor directly."""
        # Validate eagerly: known rank, a shape expression, and symbolic shape values
        # must all be present so that `shape`/`ndim`/`dtype` below never fail.
        assert _expr.struct_info_ is not None
        assert isinstance(_expr.struct_info, TensorStructInfo)
        assert _expr.struct_info.ndim != -1
        assert _expr.struct_info.shape is not None
        assert _expr.struct_info.shape.struct_info_ is not None
        assert isinstance(_expr.struct_info.shape.struct_info, ShapeStructInfo)
        assert _expr.struct_info.shape.struct_info.values is not None
        self._expr = _expr

    @staticmethod
    def from_const(data) -> "Tensor":
        """Construct a tensor from numpy constants."""
        return Tensor(_expr=rx.const(data))

    @staticmethod
    def from_scalar(data: Union[int, float], dtype: str) -> "Tensor":
        """Construct a tensor from a scalar with dtype specified."""
        return Tensor(_expr=rx.const(data, dtype=dtype))

    @property
    def shape(self) -> List[Union[int, tir.PrimExpr]]:
        """Returns the shape of the tensor as a list.

        Each dimension is either a python int or a tvm.tir.PrimExpr, depending on whether
        it is static; for example, [1, 2, tvm.tir.Var("n")] is a valid shape where the last
        dimension is dynamic while the first two dimensions are static constants.

        Returns
        -------
        shape : List[Union[int, tir.PrimExpr]]
            The shape of the tensor
        """
        shape_sinfo: ShapeStructInfo = self._expr.struct_info.shape.struct_info
        # Collapse static extents (IntImm) to plain python ints; keep symbolic ones as-is.
        return [int(dim) if isinstance(dim, tir.IntImm) else dim for dim in shape_sinfo.values]

    @property
    def ndim(self) -> int:
        """Returns the number of dimensions of the tensor.

        Returns
        -------
        ndim : int
            The number of dimensions of the tensor
        """
        return self._expr.struct_info.ndim

    @property
    def dtype(self) -> str:
        """Returns the data type of the tensor.

        Returns
        -------
        dtype : str
            The data type of the tensor
        """
        return self._expr.struct_info.dtype

    def __repr__(self) -> str:
        return f'Tensor({self.shape}, "{self.dtype}")'
+
+
class Parameter(Tensor):
    """A parameter represents the weight of a neural network layer. It is a special tensor which
    could be bound or not bound to concrete values. If a parameter is bound to a concrete value,
    it is called a bound parameter, otherwise it is called an unbound parameter.
    """

    # Concrete value the parameter is bound to, or None when unbound.
    _data: Optional[NDArray]

    def __init__(self, shape: List[Union[int, tir.PrimExpr]], dtype: Optional[str] = None) -> None:
        """Create a parameter with given shape and dtype. The parameter is not bound to any
        concrete values.

        Parameters
        ----------
        shape : List[Union[int, tir.PrimExpr]]
            The shape of the parameter
        dtype : Optional[str]
            The data type of the parameter. If not specified, the default dtype will be used.
        """
        if dtype is None:
            dtype = get_default_dtype()
        super().__init__(_expr=_tensor_placeholder("param", shape, dtype=dtype)._expr)
        self._data = None

    @property
    def data(self) -> Optional[NDArray]:
        """Returns the concrete value of the parameter if it is bound to a concrete value,
        otherwise returns None. The returned value is a tvm.runtime.NDArray."""
        return self._data

    @data.setter
    def data(self, data: Union[None, NDArray, np.ndarray, "torch.Tensor"]) -> None:
        """Set the concrete value of the parameter. The data should be one of the following:
        - None: unbind the parameter to concrete values
        - tvm.runtime.NDArray
        - numpy.ndarray
        - torch.Tensor and any other DLPack-compliant tensors
        """
        if data is None:
            self._data = data
            return
        # Try to do zero-copy if possible
        if isinstance(data, NDArray):
            pass
        elif isinstance(data, np.ndarray):
            data = ndarray.array(data)
        elif hasattr(data, "__dlpack__"):
            data = _from_dlpack(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")
        # The bound value must match the symbolic tensor's declared shape and dtype.
        if data.shape != tuple(self.shape):
            raise ValueError(f"Shape mismatch: expected {tuple(self.shape)}, got {data.shape}")
        if data.dtype != self.dtype:
            raise ValueError(f"Dtype mismatch: expected {self.dtype}, got {data.dtype}")
        self._data = data

    def to(self, dtype: Optional[str] = None) -> None:  # pylint: disable=invalid-name
        """Change the dtype of the parameter if it is not bound to any concrete data.

        Raises
        ------
        ValueError
            If the parameter is already bound to concrete data, since silently re-typing
            bound data could cause precision loss or other unexpected behaviors.
        """
        if dtype is None:
            return
        # BUGFIX: previously the placeholder re-creation was nested under
        # `dtype is not None and self._data is not None`, behind an inner check on the same
        # `self._data is not None` condition that always raised first — making the dtype
        # change unreachable and `to()` a silent no-op for unbound parameters. The change
        # must apply exactly when no concrete data is bound.
        if self._data is not None:
            raise ValueError(
                "Changing the dtype of a Parameter that has been bound to concrete "
                "data is not recommended. It might lead to potential precision loss "
                "or other unexpected behaviors"
            )
        self._expr = _tensor_placeholder(  # pylint: disable=protected-access
            "param", self.shape, dtype=dtype
        )._expr
+
+
class Effect:
    """Effect is a special non-user facing type that is used to represent operations with side
    effects, for example, print. It is used to represent the output of a computation.
    """

    def emit_init(self, name_hint: str, builder: BlockBuilder) -> List[rx.DataflowVar]:
        """Emit the initialization of the effect; invoked by the compiler to set the
        effect up. Subclasses must override this."""
        raise NotImplementedError

    def create(self, name_hint: str) -> List[rx.Var]:
        """Create the implicit inputs to a relax.Function that represents the side effect.
        Subclasses must override this."""
        raise NotImplementedError

    def finalize(self) -> List[rx.Var]:
        """Finalize the effect as the implicit return value of a relax.Function.
        Subclasses must override this."""
        raise NotImplementedError

    def to(self, dtype: Optional[str] = None) -> None:  # pylint: disable=invalid-name
        """Convert the effect to a specific dtype. By default this is a no-op, which is
        appropriate for most effects."""
+
+
class Module:
    """Base class for neural network components. Subclass it to build your models.
    Modules can nest within each other in a tree structure using regular attribute assignment."""

    def named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
        """Iterate over the module's parameters, yielding each parameter's qualified name
        together with the parameter itself.

        Parameters
        ----------
        prefix : str
            Prefix to prepend to all parameter names.

        Yields
        ------
        (str, Parameter) – Tuple containing the name and parameter
        """
        yield from _attribute_finder(
            self, prefix, condition_yield=lambda x: isinstance(x, Parameter)
        )

    def state_dict(
        self, *, prefix: str = "", destination: Optional[Dict[str, Parameter]] = None
    ) -> Dict[str, Parameter]:
        """Returns a dictionary containing references to the whole state of the module.

        Parameters
        ----------
        prefix : str
            Prefix to prepend to all parameter names.
        destination : Optional[Dict[str, Parameter]]
            Dictionary to which state will be saved. If None, a new dictionary is created.

        Returns
        -------
        dict : Dict[str, Parameter]
            a dictionary containing a whole state of the module
        """
        destination = OrderedDict() if destination is None else destination
        # dict.update accepts an iterable of (key, value) pairs.
        destination.update(self.named_parameters(prefix))
        return destination

    def load_state_dict(
        self, state_dict: Dict[str, Parameter], strict: bool = True
    ) -> Tuple[List[str], List[str]]:
        """Copy parameters and buffers from `state_dict` into this module and its
        descendants. If `strict` is True, the keys in `state_dict` must exactly match the
        keys returned by this module's `state_dict()` function.

        Parameters
        ----------
        state_dict : Dict[str, Parameter]
            A dictionary containing a whole state of the module
        strict : bool = True
            Whether to strictly enforce that the keys in `state_dict` match the keys returned by
            this module's `state_dict()` function.

        Returns
        -------
        (missing_keys, unexpected_keys) : Tuple[List[str], List[str]]
            A tuple of two lists: the missing keys and the unexpected keys.
        """
        own_params = self.state_dict()
        unexpected_keys: List[str] = []
        for key, incoming in state_dict.items():
            if key not in own_params:
                unexpected_keys.append(key)
                continue
            if incoming.data is None:
                raise ValueError(f"Parameter {key} is not set to any concrete tensor")
            own_params.pop(key).data = incoming.data
        # Whatever was not popped above received no value from `state_dict`.
        missing_keys: List[str] = list(own_params.keys())
        if strict and (missing_keys or unexpected_keys):
            raise KeyError(f"Missing keys: {missing_keys}, Unexpected keys: {unexpected_keys}")
        return missing_keys, unexpected_keys

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Call the module with the given inputs and returns the output."""
        if not hasattr(self, "forward"):
            raise NotImplementedError(f"Module {type(self)} does not have a `forward` method")
        return self.forward(*args, **kwargs)  # pylint: disable=no-member

    def to(self, dtype: Optional[str] = None) -> None:  # pylint: disable=invalid-name
        """Convert the module to specific dtype recursively"""
        for value in self.__dict__.values():
            # Forward the conversion to every attribute exposing a callable `to`.
            if callable(getattr(value, "to", None)):
                value.to(dtype=dtype)
+
+
class ModuleList(Module):
    """Holds submodules in a list."""

    def __init__(self, modules: List[Module]):
        self.modules = modules

    def __iter__(self):
        return iter(self.modules)

    def __getitem__(self, idx):
        return self.modules[idx]

    def __setitem__(self, idx, module):
        self.modules[idx] = module

    def __len__(self):
        return len(self.modules)

    def to(self, dtype: Optional[str] = None) -> None:  # pylint: disable=invalid-name
        """Convert every contained submodule to the given dtype."""
        for submodule in self:
            submodule.to(dtype=dtype)

    def forward(self, x):  # pylint: disable=invalid-name
        """Feed-forward pass of the module"""
        for submodule in self:
            x = submodule(x)
        return x
+
+
def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any], bool]):
    """Recursively yield ``(qualified_name, attribute)`` pairs for every attribute of
    ``root`` (and its nested sub-modules) for which ``condition_yield`` returns True."""
    for attr_name, attr in root.__dict__.items():
        if condition_yield(attr):
            yield prefix + attr_name, attr
        elif isinstance(attr, ModuleList):
            # Submodules held in a list are addressed as "<name>.<index>.<...>".
            for index, element in enumerate(attr):
                yield from _attribute_finder(
                    element,
                    f"{prefix}{attr_name}.{index}.",
                    condition_yield,
                )
        elif isinstance(attr, Module):
            yield from _attribute_finder(
                attr,
                f"{prefix}{attr_name}.",
                condition_yield,
            )
+
+
def _tensor_placeholder(
    name: str, shape: Sequence[Union[int, tir.PrimExpr]], dtype: str
) -> "Tensor":
    """Create a symbolic Tensor backed by a fresh relax.Var of the given shape and dtype."""
    normalized = []
    for dim in shape:
        if isinstance(dim, (int, tir.IntImm)):
            # Static extents become plain non-negative python ints.
            dim = int(dim)
            assert dim >= 0
        elif isinstance(dim, tir.PrimExpr):
            # Symbolic extents are kept as-is but must be 64-bit integer expressions.
            assert dim.dtype == "int64"
        else:
            raise TypeError(f"Invalid shape: {shape}")
        normalized.append(dim)
    return Tensor(
        _expr=rx.Var(
            name_hint=name,
            struct_info=TensorStructInfo(
                shape=normalized,
                dtype=dtype,
            ),
        )
    )
+
+
def _from_dlpack(tensor) -> NDArray:
    """Convert a DLPack-compliant tensor to a TVM NDArray, zero-copy when supported."""
    try:
        return ndarray.from_dlpack(tensor)
    except RuntimeError:
        # Zero-copy import is not available for this tensor; fall through to a copy.
        pass
    # special logic for PyTorch: round-trip through numpy and rebuild on the right device.
    # `tensor.device.index` may be None for e.g. CPU tensors, hence the `or 0` fallback.
    return ndarray.array(
        tensor.numpy(),
        device=Device(
            Device.STR2MASK[tensor.device.type],
            tensor.device.index or 0,
        ),
    )