This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new d831230  THRIFT-5326: Expand TException interface in go library
d831230 is described below

commit d831230929bb332189c9509d07102e4be9e7f681
Author: Yuxuan 'fishy' Wang <[email protected]>
AuthorDate: Tue Dec 22 09:53:58 2020 -0800

    THRIFT-5326: Expand TException interface in go library
    
    Client: go
    
    Add TExceptionType enum type, and add
    
        TExceptionType() TExceptionType
    
    function to TException definition.
    
    Also make TProtocolException unwrap-able.
---
 CHANGES.md                                         |  1 +
 compiler/cpp/src/thrift/generate/t_go_generator.cc | 33 +++++----
 lib/go/test/tests/thrifttest_handler.go            |  2 +-
 lib/go/thrift/application_exception.go             |  6 ++
 lib/go/thrift/compact_protocol.go                  |  2 +-
 lib/go/thrift/exception.go                         | 81 ++++++++++++++++++++--
 lib/go/thrift/multiplexed_protocol.go              | 14 ++--
 lib/go/thrift/protocol_exception.go                | 33 +++++++--
 lib/go/thrift/simple_server.go                     | 18 ++++-
 lib/go/thrift/transport_exception.go               |  6 ++
 lib/go/thrift/transport_exception_test.go          |  4 --
 11 files changed, 162 insertions(+), 38 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 663c4c1..8e4d08e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -17,6 +17,7 @@
 - [THRIFT-5233](https://issues.apache.org/jira/browse/THRIFT-5233) - go: Now 
all Read*, Write* and Skip functions in TProtocol accept context arg
 - [THRIFT-5152](https://issues.apache.org/jira/browse/THRIFT-5152) - go: 
TSocket and TSSLSocket now have separated connect timeout and socket timeout
 - c++: dropped support for Windows XP
+- [THRIFT-5326](https://issues.apache.org/jira/browse/THRIFT-5326) - go: 
TException interface now has a new function: TExceptionType
 
 ### Java
 
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc 
b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index 3bb2a5c..49d8bc1 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -1493,8 +1493,15 @@ void 
t_go_generator::generate_go_struct_definition(ostream& out,
 
   if (is_exception) {
     out << indent() << "func (p *" << tstruct_name << ") Error() string {" << 
endl;
-    out << indent() << "  return p.String()" << endl;
+    out << indent() << indent() << "return p.String()" << endl;
     out << indent() << "}" << endl << endl;
+
+    out << indent() << "func (" << tstruct_name << ") TExceptionType() 
thrift.TExceptionType {" << endl;
+    out << indent() << indent() << "return thrift.TExceptionTypeCompiled" << 
endl;
+    out << indent() << "}" << endl << endl;
+
+    out << indent() << "var _ thrift.TException = (*" << tstruct_name << 
")(nil)"
+        << endl << endl;
   }
 }
 
@@ -2700,8 +2707,8 @@ void t_go_generator::generate_service_server(t_service* 
tservice) {
     f_types_ << indent() << "func (p *" << serviceName
                << "Processor) Process(ctx context.Context, iprot, oprot 
thrift.TProtocol) (success bool, err "
                   "thrift.TException) {" << endl;
-    f_types_ << indent() << "  name, _, seqId, err := 
iprot.ReadMessageBegin(ctx)" << endl;
-    f_types_ << indent() << "  if err != nil { return false, err }" << endl;
+    f_types_ << indent() << "  name, _, seqId, err2 := 
iprot.ReadMessageBegin(ctx)" << endl;
+    f_types_ << indent() << "  if err2 != nil { return false, 
thrift.WrapTException(err2) }" << endl;
     f_types_ << indent() << "  if processor, ok := 
p.GetProcessorFunction(name); ok {" << endl;
     f_types_ << indent() << "    return processor.Process(ctx, seqId, iprot, 
oprot)" << endl;
     f_types_ << indent() << "  }" << endl;
@@ -2767,11 +2774,12 @@ void 
t_go_generator::generate_process_function(t_service* tservice, t_function*
                 "thrift.TException) {" << endl;
   indent_up();
   f_types_ << indent() << "args := " << argsname << "{}" << endl;
-  f_types_ << indent() << "if err = args." << read_method_name_ <<  "(ctx, 
iprot); err != nil {" << endl;
+  f_types_ << indent() << "var err2 error" << endl;
+  f_types_ << indent() << "if err2 = args." << read_method_name_ <<  "(ctx, 
iprot); err2 != nil {" << endl;
   f_types_ << indent() << "  iprot.ReadMessageEnd(ctx)" << endl;
   if (!tfunction->is_oneway()) {
     f_types_ << indent()
-               << "  x := 
thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err.Error())"
+               << "  x := 
thrift.NewTApplicationException(thrift.PROTOCOL_ERROR, err2.Error())"
                << endl;
     f_types_ << indent() << "  oprot.WriteMessageBegin(ctx, \"" << 
escape_string(tfunction->get_name())
                << "\", thrift.EXCEPTION, seqId)" << endl;
@@ -2779,7 +2787,7 @@ void t_go_generator::generate_process_function(t_service* 
tservice, t_function*
     f_types_ << indent() << "  oprot.WriteMessageEnd(ctx)" << endl;
     f_types_ << indent() << "  oprot.Flush(ctx)" << endl;
   }
-  f_types_ << indent() << "  return false, err" << endl;
+  f_types_ << indent() << "  return false, thrift.WrapTException(err2)" << 
endl;
   f_types_ << indent() << "}" << endl;
   f_types_ << indent() << "iprot.ReadMessageEnd(ctx)" << endl << endl;
 
@@ -2842,7 +2850,6 @@ void t_go_generator::generate_process_function(t_service* 
tservice, t_function*
     f_types_ << indent() << "var retval " << 
type_to_go_type(tfunction->get_returntype()) << endl;
   }
 
-  f_types_ << indent() << "var err2 error" << endl;
   f_types_ << indent() << "if ";
 
   if (!tfunction->is_oneway()) {
@@ -2892,7 +2899,7 @@ void t_go_generator::generate_process_function(t_service* 
tservice, t_function*
   if (!tfunction->is_oneway()) {
     // Avoid writing the error to the wire if it's ErrAbandonRequest
     f_types_ << indent() << "  if err2 == thrift.ErrAbandonRequest {" << endl;
-    f_types_ << indent() << "    return false, err2" << endl;
+    f_types_ << indent() << "    return false, thrift.WrapTException(err2)" << 
endl;
     f_types_ << indent() << "  }" << endl;
 
     f_types_ << indent() << "  x := 
thrift.NewTApplicationException(thrift.INTERNAL_ERROR, "
@@ -2905,7 +2912,7 @@ void t_go_generator::generate_process_function(t_service* 
tservice, t_function*
     f_types_ << indent() << "  oprot.Flush(ctx)" << endl;
   }
 
-  f_types_ << indent() << "  return true, err2" << endl;
+  f_types_ << indent() << "  return true, thrift.WrapTException(err2)" << endl;
 
   if (!x_fields.empty()) {
     f_types_ << indent() << "}" << endl;
@@ -2931,17 +2938,17 @@ void 
t_go_generator::generate_process_function(t_service* tservice, t_function*
     f_types_ << indent() << "if err2 = oprot.WriteMessageBegin(ctx, \""
                << escape_string(tfunction->get_name()) << "\", thrift.REPLY, 
seqId); err2 != nil {"
                << endl;
-    f_types_ << indent() << "  err = err2" << endl;
+    f_types_ << indent() << "  err = thrift.WrapTException(err2)" << endl;
     f_types_ << indent() << "}" << endl;
     f_types_ << indent() << "if err2 = result." << write_method_name_ << 
"(ctx, oprot); err == nil && err2 != nil {" << endl;
-    f_types_ << indent() << "  err = err2" << endl;
+    f_types_ << indent() << "  err = thrift.WrapTException(err2)" << endl;
     f_types_ << indent() << "}" << endl;
     f_types_ << indent() << "if err2 = oprot.WriteMessageEnd(ctx); err == nil 
&& err2 != nil {"
                << endl;
-    f_types_ << indent() << "  err = err2" << endl;
+    f_types_ << indent() << "  err = thrift.WrapTException(err2)" << endl;
     f_types_ << indent() << "}" << endl;
     f_types_ << indent() << "if err2 = oprot.Flush(ctx); err == nil && err2 != 
nil {" << endl;
-    f_types_ << indent() << "  err = err2" << endl;
+    f_types_ << indent() << "  err = thrift.WrapTException(err2)" << endl;
     f_types_ << indent() << "}" << endl;
     f_types_ << indent() << "if err != nil {" << endl;
     f_types_ << indent() << "  return" << endl;
diff --git a/lib/go/test/tests/thrifttest_handler.go 
b/lib/go/test/tests/thrifttest_handler.go
index 31b9ee2..7b115ec 100644
--- a/lib/go/test/tests/thrifttest_handler.go
+++ b/lib/go/test/tests/thrifttest_handler.go
@@ -179,7 +179,7 @@ func (p *ThriftTestHandler) TestException(ctx 
context.Context, arg string) (err
                x.Message = arg
                return x
        } else if arg == "TException" {
-               return thrift.TException(errors.New(arg))
+               return thrift.WrapTException(errors.New(arg))
        } else {
                return nil
        }
diff --git a/lib/go/thrift/application_exception.go 
b/lib/go/thrift/application_exception.go
index 6de37ee..32d5b01 100644
--- a/lib/go/thrift/application_exception.go
+++ b/lib/go/thrift/application_exception.go
@@ -64,6 +64,12 @@ type tApplicationException struct {
        type_   int32
 }
 
+var _ TApplicationException = (*tApplicationException)(nil)
+
+func (tApplicationException) TExceptionType() TExceptionType {
+       return TExceptionTypeApplication
+}
+
 func (e tApplicationException) Error() string {
        if e.message != "" {
                return e.message
diff --git a/lib/go/thrift/compact_protocol.go 
b/lib/go/thrift/compact_protocol.go
index 25e6d0c..a49225d 100644
--- a/lib/go/thrift/compact_protocol.go
+++ b/lib/go/thrift/compact_protocol.go
@@ -845,7 +845,7 @@ func (p *TCompactProtocol) getTType(t tCompactType) (TType, 
error) {
        case COMPACT_STRUCT:
                return STRUCT, nil
        }
-       return STOP, TException(fmt.Errorf("don't know what type: %v", t&0x0f))
+       return STOP, NewTProtocolException(fmt.Errorf("don't know what type: 
%v", t&0x0f))
 }
 
 // Given a TType value, find the appropriate TCompactProtocol.Types constant.
diff --git a/lib/go/thrift/exception.go b/lib/go/thrift/exception.go
index ea8d6f6..b6885fa 100644
--- a/lib/go/thrift/exception.go
+++ b/lib/go/thrift/exception.go
@@ -26,19 +26,86 @@ import (
 // Generic Thrift exception
 type TException interface {
        error
+
+       TExceptionType() TExceptionType
 }
 
 // Prepends additional information to an error without losing the Thrift 
exception interface
 func PrependError(prepend string, err error) error {
-       if t, ok := err.(TTransportException); ok {
-               return NewTTransportException(t.TypeId(), prepend+t.Error())
+       msg := prepend + err.Error()
+
+       if te, ok := err.(TException); ok {
+               switch te.TExceptionType() {
+               case TExceptionTypeTransport:
+                       if t, ok := err.(TTransportException); ok {
+                               return NewTTransportException(t.TypeId(), msg)
+                       }
+               case TExceptionTypeProtocol:
+                       if t, ok := err.(TProtocolException); ok {
+                               return 
NewTProtocolExceptionWithType(t.TypeId(), errors.New(msg))
+                       }
+               case TExceptionTypeApplication:
+                       if t, ok := err.(TApplicationException); ok {
+                               return NewTApplicationException(t.TypeId(), msg)
+                       }
+               }
+
+               return wrappedTException{
+                       err:            errors.New(msg),
+                       tExceptionType: te.TExceptionType(),
+               }
+       }
+
+       return errors.New(msg)
+}
+
+// TExceptionType is an enum type to categorize different "subclasses" of 
TExceptions.
+type TExceptionType byte
+
+// TExceptionType values
+const (
+       TExceptionTypeUnknown     TExceptionType = iota
+       TExceptionTypeCompiled                   // TExceptions defined in 
thrift files and generated by thrift compiler
+       TExceptionTypeApplication                // TApplicationExceptions
+       TExceptionTypeProtocol                   // TProtocolExceptions
+       TExceptionTypeTransport                  // TTransportExceptions
+)
+
+// WrapTException wraps an error into TException.
+//
+// If err is nil or already TException, it's returned as-is.
+// Otherwise it will be wrapped into TException with TExceptionType() returning
+// TExceptionTypeUnknown, and Unwrap() returning the original error.
+func WrapTException(err error) TException {
+       if err == nil {
+               return nil
        }
-       if t, ok := err.(TProtocolException); ok {
-               return NewTProtocolExceptionWithType(t.TypeId(), 
errors.New(prepend+err.Error()))
+
+       if te, ok := err.(TException); ok {
+               return te
        }
-       if t, ok := err.(TApplicationException); ok {
-               return NewTApplicationException(t.TypeId(), prepend+t.Error())
+
+       return wrappedTException{
+               err:            err,
+               tExceptionType: TExceptionTypeUnknown,
        }
+}
+
+type wrappedTException struct {
+       err            error
+       tExceptionType TExceptionType
+}
 
-       return errors.New(prepend + err.Error())
+func (w wrappedTException) Error() string {
+       return w.err.Error()
 }
+
+func (w wrappedTException) TExceptionType() TExceptionType {
+       return w.tExceptionType
+}
+
+func (w wrappedTException) Unwrap() error {
+       return w.err
+}
+
+var _ TException = wrappedTException{}
diff --git a/lib/go/thrift/multiplexed_protocol.go 
b/lib/go/thrift/multiplexed_protocol.go
index 2f7997e..cacbf6b 100644
--- a/lib/go/thrift/multiplexed_protocol.go
+++ b/lib/go/thrift/multiplexed_protocol.go
@@ -192,10 +192,10 @@ func (t *TMultiplexedProcessor) RegisterProcessor(name 
string, processor TProces
 func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out 
TProtocol) (bool, TException) {
        name, typeId, seqid, err := in.ReadMessageBegin(ctx)
        if err != nil {
-               return false, err
+               return false, NewTProtocolException(err)
        }
        if typeId != CALL && typeId != ONEWAY {
-               return false, fmt.Errorf("Unexpected message type %v", typeId)
+               return false, NewTProtocolException(fmt.Errorf("Unexpected 
message type %v", typeId))
        }
        //extract the service name
        v := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2)
@@ -204,11 +204,17 @@ func (t *TMultiplexedProcessor) Process(ctx 
context.Context, in, out TProtocol)
                        smb := NewStoredMessageProtocol(in, name, typeId, seqid)
                        return t.DefaultProcessor.Process(ctx, smb, out)
                }
-               return false, fmt.Errorf("Service name not found in message 
name: %s.  Did you forget to use a TMultiplexProtocol in your client?", name)
+               return false, NewTProtocolException(fmt.Errorf(
+                       "Service name not found in message name: %s.  Did you 
forget to use a TMultiplexProtocol in your client?",
+                       name,
+               ))
        }
        actualProcessor, ok := t.serviceProcessorMap[v[0]]
        if !ok {
-               return false, fmt.Errorf("Service name not found: %s.  Did you 
forget to call registerProcessor()?", v[0])
+               return false, NewTProtocolException(fmt.Errorf(
+                       "Service name not found: %s.  Did you forget to call 
registerProcessor()?",
+                       v[0],
+               ))
        }
        smb := NewStoredMessageProtocol(in, v[1], typeId, seqid)
        return actualProcessor.Process(ctx, smb, out)
diff --git a/lib/go/thrift/protocol_exception.go 
b/lib/go/thrift/protocol_exception.go
index 29ab75d..b088caf 100644
--- a/lib/go/thrift/protocol_exception.go
+++ b/lib/go/thrift/protocol_exception.go
@@ -40,8 +40,14 @@ const (
 )
 
 type tProtocolException struct {
-       typeId  int
-       message string
+       typeId int
+       err    error
+}
+
+var _ TProtocolException = (*tProtocolException)(nil)
+
+func (tProtocolException) TExceptionType() TExceptionType {
+       return TExceptionTypeProtocol
 }
 
 func (p *tProtocolException) TypeId() int {
@@ -49,11 +55,15 @@ func (p *tProtocolException) TypeId() int {
 }
 
 func (p *tProtocolException) String() string {
-       return p.message
+       return p.err.Error()
 }
 
 func (p *tProtocolException) Error() string {
-       return p.message
+       return p.err.Error()
+}
+
+func (p *tProtocolException) Unwrap() error {
+       return p.err
 }
 
 func NewTProtocolException(err error) TProtocolException {
@@ -64,14 +74,23 @@ func NewTProtocolException(err error) TProtocolException {
                return e
        }
        if _, ok := err.(base64.CorruptInputError); ok {
-               return &tProtocolException{INVALID_DATA, err.Error()}
+               return &tProtocolException{
+                       typeId: INVALID_DATA,
+                       err:    err,
+               }
+       }
+       return &tProtocolException{
+               typeId: UNKNOWN_PROTOCOL_EXCEPTION,
+               err:    err,
        }
-       return &tProtocolException{UNKNOWN_PROTOCOL_EXCEPTION, err.Error()}
 }
 
 func NewTProtocolExceptionWithType(errType int, err error) TProtocolException {
        if err == nil {
                return nil
        }
-       return &tProtocolException{errType, err.Error()}
+       return &tProtocolException{
+               typeId: errType,
+               err:    err,
+       }
 }
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index e9fea86..ca0e61d 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -315,7 +315,9 @@ func (p *TSimpleServer) processRequests(client TTransport) 
(err error) {
                }
 
                ok, err := processor.Process(ctx, inputProtocol, outputProtocol)
-               if err == ErrAbandonRequest {
+               // Once we drop support for pre-go1.13, this can be replaced 
by:
+               // errors.Is(err, ErrAbandonRequest)
+               if unwrapError(err) == ErrAbandonRequest {
                        return client.Close()
                }
                if _, ok := err.(TTransportException); ok && err != nil {
@@ -330,3 +332,17 @@ func (p *TSimpleServer) processRequests(client TTransport) 
(err error) {
        }
        return nil
 }
+
+type unwrapper interface {
+       Unwrap() error
+}
+
+func unwrapError(err error) error {
+       for {
+               if u, ok := err.(unwrapper); ok {
+                       err = u.Unwrap()
+               } else {
+                       return err
+               }
+       }
+}
diff --git a/lib/go/thrift/transport_exception.go 
b/lib/go/thrift/transport_exception.go
index 16193ee..cf2cc00 100644
--- a/lib/go/thrift/transport_exception.go
+++ b/lib/go/thrift/transport_exception.go
@@ -48,6 +48,12 @@ type tTransportException struct {
        err    error
 }
 
+var _ TTransportException = (*tTransportException)(nil)
+
+func (tTransportException) TExceptionType() TExceptionType {
+       return TExceptionTypeTransport
+}
+
 func (p *tTransportException) TypeId() int {
        return p.typeId
 }
diff --git a/lib/go/thrift/transport_exception_test.go 
b/lib/go/thrift/transport_exception_test.go
index fb1dc26..57386cb 100644
--- a/lib/go/thrift/transport_exception_test.go
+++ b/lib/go/thrift/transport_exception_test.go
@@ -36,10 +36,6 @@ func (t *timeout) Error() string {
        return fmt.Sprintf("Timeout: %v", t.timedout)
 }
 
-type unwrapper interface {
-       Unwrap() error
-}
-
 func TestTExceptionTimeout(t *testing.T) {
        timeout := &timeout{true}
        exception := NewTTransportExceptionFromError(timeout)

Reply via email to