davesearlegsa commented on issue #3659:
URL: https://github.com/apache/arrow-adbc/issues/3659#issuecomment-3474538351

   Actually, I've got this working with a bit of reflection. I'll leave the 
code below in case it's useful to anyone else. I appreciate that this isn't 
part of the ADBC spec or library contract, but it would be nice if we could 
get access to the underlying CAdbcConnection.
   
   
   `internal class Program
   {
       /// <summary>
       /// Managed view of the native block that <c>Main</c> reads from
       /// <c>CAdbcConnection.private_data</c> via <c>Marshal.PtrToStructure</c>.
       /// </summary>
       /// <remarks>
       /// NOTE(review): this assumes the DuckDB ADBC driver stores its
       /// <c>duckdb_connection</c> handle as the first pointer-sized field of that
       /// block. That is an implementation detail of the driver, not a documented
       /// contract - confirm against the DuckDB version in use.
       /// </remarks>
       [StructLayout(LayoutKind.Sequential)]
       public struct DuckDBAdbcConnectionWrapper
       {
           public IntPtr connection; // duckdb_connection
       }
   
       /// <summary>
       /// Exposes an <see cref="IAsyncEnumerable{T}"/> of record batches as an
       /// <see cref="IArrowArrayStream"/> so it can be exported through the Arrow
       /// C stream interface.
       /// </summary>
       public class AsyncRecordBatchStream : IArrowArrayStream
       {
           private readonly Schema _schema;
           private readonly NativeMemoryAllocator _nativeMemoryAllocator = new();
           private readonly IAsyncEnumerator<RecordBatch> _recordBatchEnumerator;
           private bool _disposed;

           /// <summary>Schema shared by every batch produced by this stream.</summary>
           public Schema Schema => _schema;

           /// <summary>Creates a stream over <paramref name="recordBatches"/>.</summary>
           /// <param name="schema">Schema reported to consumers of the stream.</param>
           /// <param name="recordBatches">Source sequence; enumerated exactly once.</param>
           /// <exception cref="ArgumentNullException">Either argument is null.</exception>
           public AsyncRecordBatchStream(Schema schema, IAsyncEnumerable<RecordBatch> recordBatches)
           {
               _schema = schema ?? throw new ArgumentNullException(nameof(schema));
               if (recordBatches is null)
               {
                   throw new ArgumentNullException(nameof(recordBatches));
               }

               _recordBatchEnumerator = recordBatches.GetAsyncEnumerator();
           }

           /// <summary>
           /// Advances the source sequence and returns a copy of the next batch that
           /// is allocated with the native memory allocator.
           /// </summary>
           /// <param name="cancellationToken">Checked before the source is advanced.</param>
           /// <returns>The next batch, or <c>null</c> once the source is exhausted.</returns>
           /// <exception cref="ObjectDisposedException">The stream was already disposed.</exception>
           public async ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
           {
               if (_disposed)
               {
                   throw new ObjectDisposedException(nameof(AsyncRecordBatchStream));
               }

               // The enumerator was created without a token, so honour cancellation here.
               cancellationToken.ThrowIfCancellationRequested();

               // ConfigureAwait(false): callers may block on this call synchronously,
               // so never require the captured context to resume.
               if (await _recordBatchEnumerator.MoveNextAsync().ConfigureAwait(false))
               {
                   Console.WriteLine("ReadNextRecordBatchAsync: Processing a record batch.");
                   return _recordBatchEnumerator.Current.Clone(_nativeMemoryAllocator);
               }

               return null;
           }

           /// <summary>
           /// Disposes the underlying enumerator so the source sequence can release
           /// whatever it holds. Safe to call more than once.
           /// </summary>
           public void Dispose()
           {
               if (_disposed)
               {
                   return;
               }

               _disposed = true;

               ValueTask pending = _recordBatchEnumerator.DisposeAsync();
               if (pending.IsCompleted)
               {
                   // Already finished: observe the result (and any exception) directly.
                   pending.GetAwaiter().GetResult();
               }
               else
               {
                   // A ValueTask must not be blocked on directly; go through a Task.
                   pending.AsTask().GetAwaiter().GetResult();
               }
           }
       }
   
       /// <summary>
       /// P/Invoke binding for DuckDB's <c>duckdb_arrow_scan</c> C function.
       /// </summary>
       /// <param name="connection">Native <c>duckdb_connection</c> handle.</param>
       /// <param name="name">
       /// Name passed to DuckDB for the scan. Marshalled explicitly as UTF-8
       /// because the default string marshalling for <c>DllImport</c> uses the ANSI
       /// code page on Windows, which would corrupt non-ASCII names.
       /// </param>
       /// <param name="arrowStream">Pointer to an exported <c>ArrowArrayStream</c>.</param>
       /// <returns>The DuckDB state code reported by the native call.</returns>
       /// <remarks>
       /// NOTE(review): the library name "duckdb.dll" only resolves on Windows;
       /// left unchanged to match the rest of this sample.
       /// </remarks>
       [DllImport("duckdb.dll", CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_arrow_scan")]
       public static extern DuckDBState duckdb_arrow_scan(
           IntPtr connection,
           [MarshalAs(UnmanagedType.LPUTF8Str)] string name,
           IntPtr arrowStream);
   
       /// <summary>
       /// Sample entry point: loads the DuckDB ADBC driver, digs the native
       /// <c>duckdb_connection</c> out of the ADBC connection via reflection,
       /// registers two Arrow streams with <c>duckdb_arrow_scan</c>, then joins
       /// them through ADBC and prints the resulting row count.
       /// </summary>
       /// <param name="args">Unused.</param>
       static void Main(string[] args)
       {
           string root = Directory.GetCurrentDirectory();
           string file = Path.Combine(root, "bin\\debug\\net9.0", "duckdb.dll");

           using AdbcDriver driver = CAdbcDriverImporter.Load(file, "duckdb_adbc_init");
           using AdbcDatabase db = driver.Open(new Dictionary<string, string> { { "path", ":memory:" } });
           using AdbcConnection conn = db.Connect(null);

           // "_nativeConnection" is a private implementation detail of the imported
           // connection type; fail with a clear message if it is missing or has a
           // different type instead of a NullReferenceException on unboxing.
           var nativeConnectionField = conn.GetType().GetField(
               "_nativeConnection",
               BindingFlags.NonPublic | BindingFlags.Instance);
           if (nativeConnectionField?.GetValue(conn) is not CAdbcConnection adbcConn)
           {
               throw new InvalidOperationException(
                   "Could not read the private '_nativeConnection' field from " + conn.GetType().FullName + ".");
           }

           IntPtr duckdbConnection;
           unsafe
           {
               if (adbcConn.private_data == null)
               {
                   throw new InvalidOperationException("The native ADBC connection has no private_data.");
               }

               DuckDBAdbcConnectionWrapper wrapper =
                   Marshal.PtrToStructure<DuckDBAdbcConnectionWrapper>((IntPtr)adbcConn.private_data);
               duckdbConnection = wrapper.connection;
           }

           if (duckdbConnection == IntPtr.Zero)
           {
               throw new InvalidOperationException("The native duckdb_connection handle is null.");
           }

           var recordBatchesA = GetRecordBatches("a", 100000);
           var streamA = new AsyncRecordBatchStream(GetSchema(), recordBatchesA.ToAsyncEnumerable());

           // Fixed: stream B previously wrapped recordBatchesA, so both tables held the same data.
           var recordBatchesB = GetRecordBatches("b", 100000);
           var streamB = new AsyncRecordBatchStream(GetSchema(), recordBatchesB.ToAsyncEnumerable());

           RegisterArrowStream(duckdbConnection, "tableA", streamA);
           RegisterArrowStream(duckdbConnection, "tableB", streamB);

           Console.WriteLine("Registered streams");

           using var statement = conn.CreateStatement();
           statement.SqlQuery = "SELECT * FROM tableA join tableB on tableA.id = tableB.id";

           var result = statement.ExecuteQuery();

           using (var stream = result.Stream ?? throw new InvalidOperationException("no results found"))
           {
               long totalRows = 0;

               // Main is synchronous, so block via AsTask(): reading .Result on a
               // ValueTask that has not completed yet is not a supported pattern.
               while (stream.ReadNextRecordBatchAsync().AsTask().GetAwaiter().GetResult() is RecordBatch batch)
               {
                   using (batch)
                   {
                       totalRows += batch.Length;
                   }
               }

               Console.WriteLine($"Row count: {totalRows}");
           }
       }

       /// <summary>
       /// Exports <paramref name="stream"/> through the Arrow C stream interface
       /// and hands it to DuckDB under <paramref name="tableName"/>.
       /// </summary>
       /// <exception cref="InvalidOperationException">DuckDB reported a failure.</exception>
       private static unsafe void RegisterArrowStream(IntPtr duckdbConnection, string tableName, IArrowArrayStream stream)
       {
           // NOTE(review): the native stream is intentionally not freed here. It
           // appears DuckDB reads from it when the query runs, so it has to outlive
           // this call - confirm who owns the final release.
           CArrowArrayStream* nativeStream = CArrowArrayStream.Create();
           CArrowArrayStreamExporter.ExportArrayStream(stream, nativeStream);

           DuckDBState state = duckdb_arrow_scan(duckdbConnection, tableName, (IntPtr)nativeStream);
           if (state != DuckDBState.Success)
           {
               throw new InvalidOperationException(
                   "duckdb_arrow_scan failed for '" + tableName + "' (state: " + state + ").");
           }
       }
   
       /// <summary>
       /// Produces a single-element list holding one record batch of
       /// <paramref name="size"/> rows: three string columns whose values are
       /// <paramref name="prefix"/> followed by "o", "c" or "p" and the row index,
       /// plus an Int32 column containing the row index itself.
       /// </summary>
       private static List<RecordBatch> GetRecordBatches(string prefix, int size)
       {
           var batchSchema = GetSchema();

           // Builds one string column whose i-th value is prefix + tag + i.
           StringArray TextColumn(string tag)
           {
               var textBuilder = new StringArray.Builder();
               textBuilder.AppendRange(Enumerable.Range(0, size).Select(row => prefix + tag + row));
               return textBuilder.Build();
           }

           var idBuilder = new Int32Array.Builder();
           idBuilder.AppendRange(Enumerable.Range(0, size));

           var columns = new IArrowArray[]
           {
               TextColumn("o"),
               TextColumn("c"),
               TextColumn("p"),
               idBuilder.Build(),
           };

           return new List<RecordBatch> { new RecordBatch(batchSchema, columns, size) };
       }
   
       /// <summary>
       /// Builds the schema used by the sample tables: nullable string fields
       /// "orderId", "customerId" and "productId", followed by a nullable
       /// Int32 field "id".
       /// </summary>
       public static Schema GetSchema()
       {
           var schemaBuilder = new Schema.Builder();

           foreach (string textFieldName in new[] { "orderId", "customerId", "productId" })
           {
               schemaBuilder.Field(new Field(textFieldName, StringType.Default, true));
           }

           schemaBuilder.Field(new Field("id", Int32Type.Default, true));

           return schemaBuilder.Build();
       }
   }`


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to