lidavidm commented on code in PR #36009:
URL: https://github.com/apache/arrow/pull/36009#discussion_r1235641556


##########
cpp/src/arrow/flight/integration_tests/test_integration.cc:
##########
@@ -410,6 +413,470 @@ class OrderedScenario : public Scenario {
   }
 };
 
+/// \brief The server used for testing FlightEndpoint.expiration_time.
+///
+/// GetFlightInfo() returns a FlightInfo that has the following
+/// three FlightEndpoints:
+///
+/// 1. No expiration time
+/// 2. 2 seconds expiration time
+/// 3. 3 seconds expiration time
+///
+/// The client can read data from the first endpoint only once, but
+/// can read data from the second and third endpoints multiple times
+/// until they expire: the second endpoint can't be re-read after 2
+/// seconds and the third endpoint can't be re-read after 3 seconds.
+/// Each call to GetFlightInfo() resets the state of all endpoints.
+///
+/// The client can cancel a returned FlightInfo with the pre-defined
+/// CancelFlightInfo action. After that, the client can't read data from
+/// the cancelled endpoints, even before they expire.
+///
+/// The client can extend the expiration time of a FlightEndpoint in
+/// a returned FlightInfo with the pre-defined RefreshFlightEndpoint
+/// action. The client can then read data from that endpoint multiple
+/// times for 10 seconds after the action (cancelled endpoints can't be refreshed).
+///
+/// The client can close a returned FlightInfo explicitly with the
+/// pre-defined CloseFlightInfo action. After that, the client can't read
+/// data from the closed endpoints, even before they expire.
+class ExpirationTimeServer : public FlightServerBase {
+ private:
+  struct EndpointStatus {
+    explicit EndpointStatus(std::optional<Timestamp> expiration_time)
+        : expiration_time(expiration_time) {}
+
+    std::optional<Timestamp> expiration_time;  // std::nullopt: single-read endpoint
+    uint32_t num_gets = 0;  // successful DoGet calls so far
+    bool cancelled = false;  // set by the CancelFlightInfo action
+    bool closed = false;  // set by the CloseFlightInfo action
+  };
+
+ public:
+  ExpirationTimeServer() : FlightServerBase(), statuses_() {}
+
+  Status GetFlightInfo(const ServerCallContext& context,
+                       const FlightDescriptor& descriptor,
+                       std::unique_ptr<FlightInfo>* result) override {
+    statuses_.clear();  // drop state from earlier calls; ticket indices restart at 0
+    auto schema = BuildSchema();
+    std::vector<FlightEndpoint> endpoints;
+    AddEndpoint(endpoints, "No expiration time", std::nullopt);
+    AddEndpoint(endpoints, "2 seconds",
+                Timestamp::clock::now() + std::chrono::seconds{2});
+    AddEndpoint(endpoints, "3 seconds",
+                Timestamp::clock::now() + std::chrono::seconds{3});
+    ARROW_ASSIGN_OR_RAISE(
+        auto info, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1, 
false));
+    *result = std::make_unique<FlightInfo>(info);
+    return Status::OK();
+  }
+
+  Status DoGet(const ServerCallContext& context, const Ticket& request,
+               std::unique_ptr<FlightDataStream>* stream) override {
+    ARROW_ASSIGN_OR_RAISE(auto index, ExtractIndexFromTicket(request.ticket));
+    auto& status = statuses_[index];
+    if (status.closed) {
+      return Status::KeyError("Invalid flight: closed: ", request.ticket);
+    }
+    if (status.cancelled) {
+      return Status::KeyError("Invalid flight: canceled: ", request.ticket);
+    }
+    if (status.expiration_time.has_value()) {  // readable repeatedly until it expires
+      auto expiration_time = status.expiration_time.value();
+      if (expiration_time < Timestamp::clock::now()) {
+        return Status::KeyError("Invalid flight: expired: ", request.ticket);
+      }
+    } else {  // no expiration time: readable only once
+      if (status.num_gets > 0) {
+        return Status::KeyError("Invalid flight: can't read multiple times: ",
+                                request.ticket);
+      }
+    }
+    status.num_gets++;
+    ARROW_ASSIGN_OR_RAISE(auto builder, RecordBatchBuilder::Make(
+                                            BuildSchema(), 
arrow::default_memory_pool()));
+    auto number_builder = builder->GetFieldAs<UInt32Builder>(0);
+    ARROW_RETURN_NOT_OK(number_builder->Append(index));  // the single row holds the endpoint index
+    ARROW_ASSIGN_OR_RAISE(auto record_batch, builder->Flush());
+    std::vector<std::shared_ptr<RecordBatch>> record_batches{record_batch};
+    ARROW_ASSIGN_OR_RAISE(auto record_batch_reader,
+                          RecordBatchReader::Make(record_batches));
+    *stream = std::make_unique<RecordBatchStream>(record_batch_reader);
+    return Status::OK();
+  }
+
+  Status DoAction(const ServerCallContext& context, const Action& action,
+                  std::unique_ptr<ResultStream>* result_stream) override {
+    std::vector<Result> results;
+    if (action.type == ActionType::kCancelFlightInfo.type) {  // one result per endpoint
+      ARROW_ASSIGN_OR_RAISE(auto info,
+                            
FlightInfo::Deserialize(std::string_view(*action.body)));
+      for (const auto& endpoint : info->endpoints()) {
+        auto index_result = ExtractIndexFromTicket(endpoint.ticket.ticket);
+        auto cancel_status = CancelStatus::kUnspecified;
+        if (index_result.ok()) {
+          auto index = *index_result;
+          if (statuses_[index].cancelled) {
+            cancel_status = CancelStatus::kNotCancellable;
+          } else {
+            statuses_[index].cancelled = true;
+            cancel_status = CancelStatus::kCancelled;
+          }
+        } else {
+          cancel_status = CancelStatus::kNotCancellable;  // ticket names no known endpoint
+        }
+        auto cancel_result = CancelFlightInfoResult{cancel_status};
+        ARROW_ASSIGN_OR_RAISE(auto serialized, 
cancel_result.SerializeToString());
+        results.push_back(Result{Buffer::FromString(std::move(serialized))});
+      }
+    } else if (action.type == ActionType::kCloseFlightInfo.type) {  // no results; unknown tickets skipped
+      ARROW_ASSIGN_OR_RAISE(auto info,
+                            
FlightInfo::Deserialize(std::string_view(*action.body)));
+      for (const auto& endpoint : info->endpoints()) {
+        auto index_result = ExtractIndexFromTicket(endpoint.ticket.ticket);
+        if (!index_result.ok()) {
+          continue;
+        }
+        auto index = *index_result;
+        statuses_[index].closed = true;
+      }
+    } else if (action.type == ActionType::kRefreshFlightEndpoint.type) {
+      ARROW_ASSIGN_OR_RAISE(auto endpoint,
+                            
FlightEndpoint::Deserialize(std::string_view(*action.body)));
+      ARROW_ASSIGN_OR_RAISE(auto index, 
ExtractIndexFromTicket(endpoint.ticket.ticket));
+      if (statuses_[index].cancelled) {  // NOTE(review): closed endpoints are not rejected here -- intended?
+        return Status::Invalid("Invalid flight: canceled: ", 
endpoint.ticket.ticket);
+      }
+      endpoint.ticket.ticket += ": refreshed (+ 10 seconds)";  // index prefix before ':' is unchanged
+      endpoint.expiration_time = Timestamp::clock::now() + 
std::chrono::seconds{10};
+      statuses_[index].expiration_time = endpoint.expiration_time.value();  // also lifts the single-read limit
+      ARROW_ASSIGN_OR_RAISE(auto serialized, endpoint.SerializeToString());
+      results.push_back(Result{Buffer::FromString(std::move(serialized))});
+    } else {
+      return Status::Invalid("Unknown action: ", action.type);
+    }
+    *result_stream = std::make_unique<SimpleResultStream>(std::move(results));
+    return Status::OK();
+  }
+
+  Status ListActions(const ServerCallContext& context,
+                     std::vector<ActionType>* actions) override {
+    *actions = {
+        ActionType::kCancelFlightInfo,
+        ActionType::kCloseFlightInfo,
+        ActionType::kRefreshFlightEndpoint,
+    };
+    return Status::OK();
+  }
+
+ private:
+  void AddEndpoint(std::vector<FlightEndpoint>& endpoints, std::string ticket,
+                   std::optional<Timestamp> expiration_time) {
+    endpoints.push_back(FlightEndpoint{
+        {std::to_string(statuses_.size()) + ": " + ticket}, {}, 
expiration_time});
+    statuses_.emplace_back(expiration_time);  // the index encoded in the ticket refers to this entry
+  }
+
+  arrow::Result<uint32_t> ExtractIndexFromTicket(const std::string& ticket) {
+    auto index_string = arrow::internal::SplitString(ticket, ':', 2)[0];  // tickets are "<index>: <label>"
+    uint32_t index;
+    if (!arrow::internal::ParseUnsigned(index_string.data(), 
index_string.length(),
+                                        &index)) {
+      return Status::KeyError("Invalid flight: no index: ", ticket);
+    }
+    if (index >= statuses_.size()) {
+      return Status::KeyError("Invalid flight: out of index: ", ticket);
+    }
+    return index;
+  }
+
+  std::shared_ptr<Schema> BuildSchema() {
+    return arrow::schema({arrow::field("number", arrow::uint32(), false)});
+  }
+
+  std::vector<EndpointStatus> statuses_;  // NOTE(review): no mutex; assumes RPCs don't overlap -- confirm
+};
+
+/// \brief The expiration time scenario - DoGet.
+///
+/// This tests that the client can read data that isn't expired yet
+/// multiple times and can't read data after it's expired.
+class ExpirationTimeDoGetScenario : public Scenario {
+  Status MakeServer(std::unique_ptr<FlightServerBase>* server,
+                    FlightServerOptions* options) override {
+    *server = std::make_unique<ExpirationTimeServer>();
+    return Status::OK();
+  }
+
+  Status MakeClient(FlightClientOptions* options) override { return 
Status::OK(); }
+
+  Status RunClient(std::unique_ptr<FlightClient> client) override {
+    ARROW_ASSIGN_OR_RAISE(
+        auto info, 
client->GetFlightInfo(FlightDescriptor::Command("expiration_time")));
+    std::vector<std::shared_ptr<arrow::Table>> tables;
+    // First read from all endpoints
+    for (const auto& endpoint : info->endpoints()) {
+      ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(endpoint.ticket));
+      ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable());
+      tables.push_back(table);
+    }
+    // Re-reads only from endpoints that have expiration time
+    for (const auto& endpoint : info->endpoints()) {
+      if (endpoint.expiration_time.has_value()) {
+        ARROW_ASSIGN_OR_RAISE(auto reader, client->DoGet(endpoint.ticket));
+        ARROW_ASSIGN_OR_RAISE(auto table, reader->ToTable());
+        tables.push_back(table);
+      } else {
+        auto reader = client->DoGet(endpoint.ticket);
+        if (reader.ok()) {
+          return Status::Invalid(
+              "Data that doesn't have expiration time "
+              "shouldn't be readable multiple times");
+        }
+      }
+    }
+    // Re-reads after expired
+    for (const auto& endpoint : info->endpoints()) {
+      if (!endpoint.expiration_time.has_value()) {
+        continue;
+      }
+      const auto& expiration_time = endpoint.expiration_time.value();
+      if (expiration_time > Timestamp::clock::now()) {
+        std::this_thread::sleep_for(expiration_time - Timestamp::clock::now());
+      }

Review Comment:
   I wonder if this might be flaky in CI (at least the previous loop): if more
   than 2 seconds somehow elapse after GetFlightInfo(), the re-read above will
   fail. Maybe there's no need to test the actual expiration of the ticket? Or
   we could treat the expiration as a counter instead of a real expiration time,
   so that the test doesn't depend on the clock.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to