featzhang created FLINK-39074:
---------------------------------
Summary: Add built-in AI model inference capability to PyFlink
with automatic lifecycle management
Key: FLINK-39074
URL: https://issues.apache.org/jira/browse/FLINK-39074
Project: Flink
Issue Type: New Feature
Components: API / Python
Reporter: featzhang
h3. Overview
PyFlink currently lacks native support for AI model inference, forcing users to
manually manage models, batching, and resources. This proposal introduces a new
{{DataStream.infer()}} API that provides out-of-the-box AI inference
capabilities with automatic lifecycle management.
h3. Motivation
*Current Pain Points:*
1. Users must manually load/unload models
2. No built-in batching for inference optimization
3. No standardized resource management
4. Lack of warmup and performance optimization strategies
*User Impact:*
- Complicated boilerplate code for inference
- Suboptimal performance due to lack of batching
- Resource leaks from improper model management
h3. Proposed Solution
h4. 1. Simple API
{code:python}
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()

# Text embedding example
result = data_stream.infer(
    model="sentence-transformers/all-MiniLM-L6-v2",
    input_col="text",
    output_col="embedding"
)

# Sentiment classification example
sentiment = data_stream.infer(
    model="distilbert-base-uncased-finetuned-sst-2-english",
    input_col="text",
    output_col="sentiment",
    task_type="classification"
)
{code}
h4. 2. Architecture
{noformat}
DataStream.infer()
        ↓
InferenceFunction (MapFunction)
        ↓
ModelLifecycleManager
        ├── Model Loading (HuggingFace/Local)
        ├── Model Warmup
        └── Resource Management
        ↓
BatchInferenceExecutor
        ├── Tokenization
        ├── Batch Inference
        └── Result Extraction
        ↓
InferenceMetrics
{noformat}
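A minimal sketch of how these layers might compose in the Phase 1 MapFunction-based implementation is shown below. The class and method names ({{ModelLifecycleManager}}, {{BatchInferenceExecutor}}, {{load_model}}, {{run_batch}}, etc.) follow the architecture diagram and module list but are assumptions, not a final API:
{code:python}
from pyflink.datastream.functions import MapFunction, RuntimeContext

# Class names inside these proposed modules are assumptions for illustration.
from pyflink.ml.inference.lifecycle import ModelLifecycleManager
from pyflink.ml.inference.executor import BatchInferenceExecutor


class InferenceFunction(MapFunction):
    """Wraps model lifecycle and batch execution behind a plain MapFunction."""

    def __init__(self, config):
        self._config = config
        self._lifecycle = None
        self._executor = None

    def open(self, runtime_context: RuntimeContext):
        # Load the model once per task and warm it up before processing records.
        self._lifecycle = ModelLifecycleManager(self._config)
        self._lifecycle.load_model()
        if self._config.model_warmup:
            self._lifecycle.warmup()
        self._executor = BatchInferenceExecutor(self._lifecycle, self._config)

    def map(self, value):
        # Phase 1: a single-record "batch"; true batching arrives with FLINK-38825.
        text = value[self._config.input_col]
        prediction = self._executor.run_batch([text])[0]
        return {**value, self._config.output_col: prediction}

    def close(self):
        # Release model weights / GPU memory when the task shuts down.
        if self._lifecycle is not None:
            self._lifecycle.release()
{code}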
h4. 3. Key Features
*Model Lifecycle Management*
- Automatic model loading from HuggingFace Hub or local path
- Model warmup for optimal performance
- Proper cleanup and resource deallocation
*Batch Inference*
- Configurable batch size
- Batch timeout control
- Future: Integration with AsyncBatchFunction (FLINK-38825)
*Multi-Task Support*
- Text embedding
- Text classification
- Text generation
*Resource Optimization*
- CPU/GPU device selection
- FP16 precision support
- CUDA memory management
*Metrics & Monitoring*
- Inference latency (avg/p50/p95/p99)
- Throughput tracking
- Error rate monitoring
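The latency, throughput, and error-rate figures listed above could be aggregated roughly as in the sketch below ({{InferenceMetrics}} is the component name from the architecture diagram; its methods and fields are assumptions):
{code:python}
import time
import numpy as np


class InferenceMetrics:
    """Rough per-task aggregation of inference latencies (sketch only)."""

    def __init__(self):
        self._started = time.time()
        self._latencies_ms = []
        self._errors = 0

    def record(self, latency_ms, failed=False):
        self._latencies_ms.append(latency_ms)
        if failed:
            self._errors += 1

    def snapshot(self):
        if not self._latencies_ms:
            return {}
        lat = np.asarray(self._latencies_ms)
        elapsed_s = max(time.time() - self._started, 1e-9)
        return {
            "avg_ms": float(lat.mean()),
            "p50_ms": float(np.percentile(lat, 50)),
            "p95_ms": float(np.percentile(lat, 95)),
            "p99_ms": float(np.percentile(lat, 99)),
            "throughput_rps": len(lat) / elapsed_s,
            "error_rate": self._errors / len(lat),
        }
{code}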
h4. 4. Configuration Options
||Parameter||Type||Default||Description||
|model|string|-|Model name (HuggingFace) or local path (required)|
|input_col|string|-|Input column name (required)|
|output_col|string|-|Output column name (required)|
|batch_size|int|32|Batch size for inference|
|max_batch_timeout_ms|int|100|Maximum batch wait time in milliseconds|
|model_warmup|bool|true|Enable model warmup|
|device|string|"cpu"|Device: cpu, cuda:0, etc.|
|num_workers|int|1|Number of worker processes|
|task_type|string|"embedding"|Task type: embedding/classification/generation|
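Putting these options together, a fully configured classification call might look like the following (the values are illustrative, and whether every keyword is accepted in the Phase 1 MVP is an assumption):
{code:python}
sentiment = data_stream.infer(
    model="distilbert-base-uncased-finetuned-sst-2-english",
    input_col="text",
    output_col="sentiment",
    task_type="classification",
    batch_size=64,              # records per inference batch
    max_batch_timeout_ms=50,    # flush a partial batch after 50 ms
    model_warmup=True,          # run a warmup pass before processing
    device="cuda:0",            # use the first GPU; "cpu" is the default
    num_workers=2,              # number of worker processes
)
{code}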
h3. Implementation Status
*✅ Completed (Phase 1 - MVP)*
Python Modules:
- {{pyflink.ml.inference.config}} - Configuration management
- {{pyflink.ml.inference.lifecycle}} - Model lifecycle management
- {{pyflink.ml.inference.executor}} - Batch inference execution
- {{pyflink.ml.inference.function}} - MapFunction implementation
- {{pyflink.ml.inference.metrics}} - Metrics collection
- {{pyflink.datastream.data_stream.infer()}} - Public API
Unit Tests:
- Configuration validation
- Metrics collection
- Mock-based inference tests
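A configuration-validation test in this spirit might look like the following sketch ({{InferenceConfig}} is an assumed class name inside {{pyflink.ml.inference.config}}; the defaults mirror the table above, and rejecting an empty model name is an assumed behavior):
{code:python}
import unittest

# InferenceConfig is an assumed class name inside pyflink.ml.inference.config.
from pyflink.ml.inference.config import InferenceConfig


class InferenceConfigTest(unittest.TestCase):

    def test_defaults(self):
        config = InferenceConfig(
            model="sentence-transformers/all-MiniLM-L6-v2",
            input_col="text",
            output_col="embedding",
        )
        self.assertEqual(config.batch_size, 32)
        self.assertEqual(config.device, "cpu")
        self.assertEqual(config.task_type, "embedding")

    def test_missing_model_rejected(self):
        # Assumes an empty/missing model identifier raises ValueError.
        with self.assertRaises(ValueError):
            InferenceConfig(model="", input_col="text", output_col="embedding")
{code}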
*⏳ In Progress / Planned*
Phase 2:
- [ ] Java-side InferenceOperator for better integration
- [ ] AsyncBatchFunction integration (depends on FLINK-38825)
- [ ] Python Worker pool management
Phase 3:
- [ ] ONNX Runtime support
- [ ] TensorFlow model support
- [ ] Model quantization
- [ ] Multi-model pipeline
h3. Dependencies
*Required Python Packages:*
{noformat}
torch>=2.0.0
transformers>=4.30.0
numpy>=1.21.0
{noformat}
*Optional:*
{noformat}
cuda-python>=11.0    # For GPU support
onnxruntime>=1.14.0  # For ONNX models
{noformat}
h3. Code Statistics
||Module||Lines||
|{{config.py}}|90|
|{{lifecycle.py}}|188|
|{{executor.py}}|228|
|{{function.py}}|171|
|{{metrics.py}}|96|
|{{__init__.py}}|50|
|{{tests/test_inference.py}}|278|
|*Total*|*1,101*|
h3. Testing
*Unit Tests:* ✅ 11 tests passing
- Configuration validation
- Metrics calculation
- Mock inference logic
*Integration Tests:* ⏳ Planned
- End-to-end inference with real models
- GPU inference tests
- Performance benchmarks
h3. Performance Expectations (Estimated)
||Scenario||Current (manual)||With .infer()||Improvement||
|CPU (batch=1)|10 rec/s|100 rec/s|*10x*|
|CPU (batch=32)|50 rec/s|500 rec/s|*10x*|
|GPU (batch=64)|200 rec/s|2000 rec/s|*10x*|
h3. Known Limitations (Phase 1)
# *Batching:* Currently processes records one-by-one via MapFunction. True batching requires AsyncBatchFunction integration (FLINK-38825).
# *Java Integration:* Pure Python implementation; the Java-side operator is not implemented yet.
# *Model Support:* Currently HuggingFace Transformers only; ONNX/TensorFlow support is planned.
h3. Documentation
*Planned Documentation:*
- User guide with examples
- API reference
- Best practices guide
- Performance tuning guide
h3. Risk Assessment
*Technical Risks:*
- Python/Java serialization overhead
- Memory management complexity
- GPU resource contention
*Mitigation:*
- Use Arrow format for efficient serialization
- Implement memory monitoring
- Provide GPU isolation options
h3. References
* [HuggingFace Transformers|https://huggingface.co/docs/transformers]
* [PyTorch Inference Guide|https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html]
* Feature Design Doc
--
This message was sent by Atlassian Jira
(v8.20.10#820010)