Posted to commits@tvm.apache.org by "Ubospica (via GitHub)" <gi...@apache.org> on 2023/07/04 23:44:26 UTC

[GitHub] [tvm] Ubospica opened a new pull request, #15231: [Unity][Training] Registering te gradient

Ubospica opened a new pull request, #15231:
URL: https://github.com/apache/tvm/pull/15231

   This PR adds support for registering TE (Tensor Expression) gradient functions. When the Gradient pass encounters a `call_tir` node in the forward code, it calls the TE gradient function associated with that node.
   
   The current workflow is as follows:
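   ```python
   from tvm import te
   from tvm.script import ir as I
   from tvm.script import relax as R
   from tvm.script import tir as T

   # Register a TE gradient function under the name "f_mul_grad".
   # (register_te_gradient is the decorator introduced by this PR.)
   @register_te_gradient("f_mul_grad")
   def f_mul_grad(output_grad: te.Tensor, src1: te.Tensor, src2: te.Tensor, k: int):
       # Returns a list of TE tensors: the gradients w.r.t. src1 and src2.
       # k is a constant parameter, supplied through te_grad_kwargs below.
       ...

   # IRModule definition
   @I.ir_module
   class Module:
       @T.prim_func
       def f_mul(A, B, result):
           ...

       @R.function
       def main(a, b):
           cls = Module
           with R.dataflow():
               # te_grad_name and te_grad_kwargs tell the Gradient pass which
               # registered TE gradient function to call for this node.
               lv = R.call_tir(cls.f_mul, (a, b), te_grad_name="f_mul_grad", te_grad_kwargs={"k": 1})
               gv = R.output(lv)
           return gv
   ```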
   
   Note that this PR also defines an attributes type, `CallTIRAttrs`, for the `call_tir` node:
   ```cpp
   struct CallTIRAttrs : public tvm::AttrsNode<CallTIRAttrs> {
     Optional<String> te_grad_name;
     Map<String, ObjectRef> te_grad_kwargs;
   
     TVM_DECLARE_ATTRS(CallTIRAttrs, "relax.attrs.CallTIRAttrs") {
       TVM_ATTR_FIELD(te_grad_name)
           .describe("The name of the te gradient function associated with this call_tir node.");
       TVM_ATTR_FIELD(te_grad_kwargs)
           .describe("The keyword arguments passed to the te gradient function.");
     }
   };  // struct CallTIRAttrs
   ```
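The role of the two fields can be illustrated by mirroring the struct as a Python dataclass. This is a hypothetical sketch, not TVM's Python binding for `CallTIRAttrs`: the helper `resolve_te_gradient` and the toy gradient `scale_grad` are illustrative names, and checking `te_grad_kwargs` against the gradient function's signature with `inspect.signature(...).bind` is an assumed validation step that the PR does not describe. A node without `te_grad_name` yields no gradient function; when a name is present, its kwargs are forwarded as keyword arguments.

```python
import inspect
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional


@dataclass
class CallTIRAttrs:
    # Python mirror of the C++ fields (Optional<String>, Map<String, ObjectRef>).
    te_grad_name: Optional[str] = None
    te_grad_kwargs: Dict[str, Any] = field(default_factory=dict)


def resolve_te_gradient(attrs: CallTIRAttrs,
                        registry: Dict[str, Callable],
                        num_inputs: int) -> Optional[Callable]:
    """Hypothetical helper: return a callable taking (output_grad, *inputs),
    or None when the node carries no gradient name."""
    if attrs.te_grad_name is None:
        return None
    if attrs.te_grad_name not in registry:
        raise KeyError(f"no TE gradient registered as {attrs.te_grad_name!r}")
    fn = registry[attrs.te_grad_name]
    # Assumed validation: output_grad + inputs + kwargs must fit fn's signature.
    # Signature.bind raises TypeError on a missing or unexpected argument.
    inspect.signature(fn).bind(*([None] * (1 + num_inputs)), **attrs.te_grad_kwargs)
    kwargs = dict(attrs.te_grad_kwargs)
    return lambda output_grad, *inputs: fn(output_grad, *inputs, **kwargs)


def scale_grad(output_grad, src, k):
    # Hypothetical forward op: result[i] = k * src[i], so d(result)/d(src) = k.
    return [[g * k for g in output_grad]]


registry = {"scale_grad": scale_grad}
grad_fn = resolve_te_gradient(CallTIRAttrs("scale_grad", {"k": 3}), registry, num_inputs=1)
print(grad_fn([1.0, 2.0], [10.0, 20.0]))  # [[3.0, 6.0]]
print(resolve_te_gradient(CallTIRAttrs(), registry, num_inputs=1))  # None
```

An optional name plus an empty default map mirrors the C++ declaration: a `call_tir` node with no TE gradient needs neither field.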


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] tvm-bot commented on pull request #15231: [Unity][Training] Registering te gradient

Posted by "tvm-bot (via GitHub)" <gi...@apache.org>.
tvm-bot commented on PR #15231:
URL: https://github.com/apache/tvm/pull/15231#issuecomment-1620834919

   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.

   Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)


[GitHub] [tvm] tqchen commented on pull request #15231: [Unity][Training] Registering te gradient

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen commented on PR #15231:
URL: https://github.com/apache/tvm/pull/15231#issuecomment-1627471609

   Thanks @Ubospica. Let us introduce a new intrinsic, `call_tir_with_grad`, given that this is a temporary workaround until TIR-level automatic differentiation (AD) is available.


[GitHub] [tvm] tqchen merged pull request #15231: [Unity][Training] Registering te gradient

Posted by "tqchen (via GitHub)" <gi...@apache.org>.
tqchen merged PR #15231:
URL: https://github.com/apache/tvm/pull/15231

