Multimodal Customer Insights Generator with FSDP & Inferentia
Overview
This project implements a multimodal AI system that analyzes customer feedback combining text and images. It uses Fully Sharded Data Parallel (FSDP) training and Reinforcement Learning from Human Feedback (RLHF), and it is optimized for AWS Inferentia hardware. The solution is built for production use, with deployment on Amazon SageMaker and monitoring through Amazon CloudWatch.
Key Features
- Multimodal Model: Combines a CLIP vision model and a LLaMA-3 language model.
- Distributed Training: Utilizes FSDP with CPU offloading for scalable, efficient training.
- RLHF Integration: Fine-tunes the model on human preference feedback to improve the relevance and quality of its outputs.
- AWS Inferentia Optimization: Improves inference speed and reduces costs.
- SageMaker Deployment: Supports scalable and reliable deployment.
- Streamlit Dashboard: Offers a user-friendly interface for interacting with the model.
- CloudWatch Monitoring: Tracks inference performance and operational metrics.
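To illustrate what "sharding" means in the FSDP feature above, the sketch below models the core idea in plain Python: parameters are flattened into one buffer, padded so the buffer divides evenly by the number of ranks, and each rank keeps only its own contiguous slice. This is a conceptual sketch under that assumption, not PyTorch's implementation. In PyTorch the actual wrapper is torch.distributed.fsdp.FullyShardedDataParallel, and CPU offloading is enabled by passing cpu_offload=CPUOffload(offload_params=True). The function name shard_flat_params is hypothetical.

```python
import math


def shard_flat_params(param_sizes, world_size):
    """Split a flattened parameter buffer evenly across ranks (conceptual sketch).

    Returns the (start, end) slice owned by each rank and the number of
    padding elements appended so the buffer divides evenly.
    """
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    total = sum(param_sizes)                    # all parameters flattened into one buffer
    shard_size = math.ceil(total / world_size)  # every rank owns the same number of elements
    padding = shard_size * world_size - total   # zeros appended to the end of the buffer
    shards = [(rank * shard_size, (rank + 1) * shard_size) for rank in range(world_size)]
    return shards, padding


if __name__ == "__main__":
    shards, padding = shard_flat_params([10, 7, 3], world_size=3)
    print(shards, padding)
```

For example, three parameters of sizes 10, 7 and 3 on 3 ranks form a 20-element buffer padded to 21, so each rank owns 7 elements. Conceptually, CPU offloading keeps each rank's shard in host memory and copies it to the accelerator only while it is needed for computation.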
Architecture
- Data Ingestion: Customer feedback text and images.
- Preprocessing: Tokenization and image transformation.
- Multimodal Fusion: Combining visual and textual embeddings.
- Model Training: FSDP with RLHF fine-tuning.
- Compilation: Model compiled for AWS Inferentia.
- Deployment: Hosted on SageMaker.
- Monitoring: CloudWatch integration for performance tracking.
- Business Interface: Insights available through a Streamlit dashboard.
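The README does not say how the Multimodal Fusion stage combines the two embeddings. One common pattern for pairing a CLIP encoder with a decoder-only language model (used by LLaVA-style models) is to project each image patch embedding into the language model's hidden size and prepend the results to the text token embeddings. The pure-Python sketch below assumes that design; project, build_fused_sequence and the toy weight matrix are hypothetical, and a real model would learn the projection weights.

```python
def project(vec, weight):
    """Multiply a vector by a weight matrix given as rows of shape (out_dim, in_dim)."""
    if any(len(row) != len(vec) for row in weight):
        raise ValueError("weight rows must match the input dimension")
    return [sum(w * x for w, x in zip(row, vec)) for row in weight]


def build_fused_sequence(image_patch_embs, text_token_embs, weight):
    """Prepend projected image patches ("visual tokens") to the text token embeddings."""
    visual_tokens = [project(patch, weight) for patch in image_patch_embs]
    hidden = len(weight)  # language-model hidden size after projection
    if any(len(tok) != hidden for tok in text_token_embs):
        raise ValueError("text embeddings must match the projected hidden size")
    # The language model then attends over visual and text tokens as one sequence.
    return visual_tokens + [list(tok) for tok in text_token_embs]


if __name__ == "__main__":
    W = [[1.0, 0.0, 0.0],
         [0.0, 1.0, 1.0]]              # toy 3 -> 2 projection (learned in a real model)
    patches = [[1.0, 2.0, 3.0]]        # one CLIP-style patch embedding (dim 3)
    tokens = [[0.5, 0.5], [0.1, 0.9]]  # two text token embeddings (dim 2)
    print(build_fused_sequence(patches, tokens, W))
```

Prepending keeps the language model unchanged: it simply sees a longer input sequence whose first positions carry visual information.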
Deployment Guide
Environment Setup
- Install dependencies from requirements-prod.txt.
- Set environment variables for the S3 bucket and the SageMaker execution role.
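A small startup check can fail fast when the variables from this step are missing or malformed. The sketch assumes the variables are named S3_BUCKET and SAGEMAKER_ROLE_ARN; those names and the load_settings helper are hypothetical, so substitute whatever the project's scripts actually read.

```python
import os
import re

# Hypothetical variable names; replace with the names the project's scripts read.
REQUIRED_VARS = ("S3_BUCKET", "SAGEMAKER_ROLE_ARN")
# IAM role ARNs look like arn:aws:iam::123456789012:role/RoleName (other partitions such as aws-cn exist).
ROLE_ARN_PATTERN = re.compile(r"^arn:aws[a-z-]*:iam::\d{12}:role/.+$")


def load_settings(env=None):
    """Read and validate deployment settings from the environment (or a supplied mapping)."""
    env = os.environ if env is None else env
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise RuntimeError("Missing environment variables: " + ", ".join(missing))
    role_arn = env["SAGEMAKER_ROLE_ARN"]
    if not ROLE_ARN_PATTERN.match(role_arn):
        raise ValueError(f"Not an IAM role ARN: {role_arn!r}")
    return {"s3_bucket": env["S3_BUCKET"], "sagemaker_role": role_arn}


if __name__ == "__main__":
    demo_env = {
        "S3_BUCKET": "feedback-bucket",
        "SAGEMAKER_ROLE_ARN": "arn:aws:iam::123456789012:role/SageMakerRole",
    }
    print(load_settings(demo_env))
```

Accepting an optional mapping instead of always reading os.environ keeps the check easy to unit test.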
Training with RLHF
- Run training with the provided "Training with RLHF.sh" script.
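The README does not describe the RLHF objective. A typical pipeline first trains a reward model on pairs of preferred and rejected responses, then optimizes the policy against that reward with a KL penalty that keeps it close to the reference model. The sketch below computes both quantities in plain Python, assuming that standard setup; the function names and the beta value are illustrative, not taken from the training script.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_reward_loss(reward_chosen, reward_rejected):
    """Bradley-Terry reward-model loss: -log(sigmoid(r_chosen - r_rejected))."""
    return softplus(-(reward_chosen - reward_rejected))


def batch_reward_loss(pairs):
    """Mean loss over (chosen_reward, rejected_reward) pairs."""
    if not pairs:
        raise ValueError("need at least one preference pair")
    return sum(pairwise_reward_loss(c, r) for c, r in pairs) / len(pairs)


def kl_penalized_reward(reward, logprob_policy, logprob_ref, beta=0.1):
    """Reward used during policy optimization, penalizing drift from the reference model.

    logprob_policy - logprob_ref is a per-sample estimate of the KL divergence.
    """
    return reward - beta * (logprob_policy - logprob_ref)
```

When the reward model cannot tell the two responses apart (equal rewards), the loss is log 2; it shrinks toward zero as the chosen response is scored increasingly higher than the rejected one.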
Model Compilation
- Compile the model for Inferentia with the "Inferentia Compilation.sh" script.
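Models compiled for Inferentia with the AWS Neuron SDK (for example via torch_neuronx.trace on Inf2) are traced for fixed input shapes, so text inputs are usually padded or truncated to a set length before inference. The sketch below assumes a bucketing strategy with one compiled model per sequence length; the bucket sizes and function names are hypothetical.

```python
def pad_to_static_length(token_ids, max_len, pad_id=0):
    """Truncate or right-pad token IDs to exactly max_len and build the attention mask."""
    ids = list(token_ids[:max_len])
    mask = [1] * len(ids)
    pad = max_len - len(ids)
    return ids + [pad_id] * pad, mask + [0] * pad


def choose_bucket(length, buckets=(128, 256, 512)):
    """Pick the smallest compiled length that fits; longer inputs use the largest bucket."""
    for size in sorted(buckets):
        if length <= size:
            return size
    return max(buckets)  # the input will be truncated to this length


def prepare_static_input(token_ids, buckets=(128, 256, 512), pad_id=0):
    """Return the chosen bucket plus fixed-shape IDs and mask for that compiled model."""
    bucket = choose_bucket(len(token_ids), buckets)
    ids, mask = pad_to_static_length(token_ids, bucket, pad_id)
    return bucket, ids, mask
```

Bucketing trades a few extra compiled artifacts for less wasted computation on padding than a single large fixed length would cause.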
Deployment
- Deploy the model to SageMaker using the "Production-Ready Deployment" script.
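The endpoint's request format is not documented here. As a sketch, assuming a JSON body with the feedback text and a base64-encoded image (the field names text and image, and both helper functions, are hypothetical), a client and an inference handler could serialize and parse requests like this:

```python
import base64
import json


def build_request(text, image_bytes):
    """Client side: encode feedback text and raw image bytes as a JSON request body."""
    payload = {
        "text": text,
        "image": base64.b64encode(image_bytes).decode("ascii"),  # JSON cannot carry raw bytes
    }
    return json.dumps(payload).encode("utf-8")


def parse_request(body):
    """Handler side: recover the text and the original image bytes."""
    payload = json.loads(body.decode("utf-8"))
    return payload["text"], base64.b64decode(payload["image"])
```

The resulting bytes could be sent with boto3's SageMaker runtime client, for example boto3.client("sagemaker-runtime").invoke_endpoint(EndpointName=..., ContentType="application/json", Body=body). Base64 inflates the image by about a third, so very large images may be better uploaded to S3 and referenced by key.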
Dashboard Launch
- Start the Streamlit dashboard with the "Launch Dashboard.sh" script.
Testing
Unit tests are included to validate model outputs and training components.
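As an illustration of the kind of output check such tests might contain, the sketch below validates a model response against a hypothetical schema (a sentiment label plus a non-empty summary) using the standard unittest module. The schema, validate_insight and ALLOWED_SENTIMENTS are assumptions, not the project's actual interface.

```python
import unittest

# Hypothetical label set for the model's sentiment field.
ALLOWED_SENTIMENTS = {"positive", "neutral", "negative"}


def validate_insight(output):
    """Return a list of problems found in one model output (empty list means valid)."""
    if not isinstance(output, dict):
        return ["output is not a dict"]
    problems = []
    if output.get("sentiment") not in ALLOWED_SENTIMENTS:
        problems.append("invalid sentiment")
    summary = output.get("summary")
    if not isinstance(summary, str) or not summary.strip():
        problems.append("empty summary")
    return problems


class TestValidateInsight(unittest.TestCase):
    def test_valid_output_passes(self):
        output = {"sentiment": "negative", "summary": "Packaging was damaged."}
        self.assertEqual(validate_insight(output), [])

    def test_unknown_sentiment_is_flagged(self):
        self.assertIn("invalid sentiment", validate_insight({"sentiment": "angry", "summary": "x"}))

    def test_blank_summary_is_flagged(self):
        self.assertIn("empty summary", validate_insight({"sentiment": "neutral", "summary": "  "}))
```

Saved in a file such as test_insights.py (a hypothetical name), these tests run with python -m unittest.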
Monitoring
The system logs key performance metrics to CloudWatch, with alarms configured to trigger on high inference latency.
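The sketch below shows two pieces of that setup under stated assumptions: building a latency datum in the shape accepted by CloudWatch's PutMetricData API, and reproducing the "M out of N" rule a CloudWatch alarm uses to decide whether recent datapoints breach a threshold. The metric name InferenceLatency, the dimension and the threshold values are hypothetical. SageMaker endpoints also publish a built-in ModelLatency metric (reported in microseconds) that an alarm could use instead.

```python
def build_latency_datum(endpoint_name, latency_ms):
    """One MetricData entry in the shape CloudWatch PutMetricData accepts."""
    return {
        "MetricName": "InferenceLatency",  # hypothetical metric name
        "Dimensions": [{"Name": "EndpointName", "Value": endpoint_name}],
        "Value": float(latency_ms),
        "Unit": "Milliseconds",
    }


def should_alarm(datapoints, threshold_ms, evaluation_periods=5, datapoints_to_alarm=3):
    """Mimic a CloudWatch alarm with GreaterThanThreshold: fire when at least
    datapoints_to_alarm of the last evaluation_periods values exceed the threshold.

    Missing-data handling, which real alarms configure separately, is ignored here.
    """
    recent = datapoints[-evaluation_periods:]
    breaching = sum(1 for value in recent if value > threshold_ms)
    return breaching >= datapoints_to_alarm
```

In a deployment, the datum would be passed as MetricData=[datum] to boto3.client("cloudwatch").put_metric_data(Namespace=..., MetricData=...), and the alarm itself would be created with put_metric_alarm rather than evaluated in application code. Requiring several breaching datapoints keeps a single slow request from paging anyone.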
Business Value
- Reduced inference costs by up to 65% compared to GPU solutions.
- Improved training scalability and speed through FSDP-based distributed training.
- Richer customer insights from analyzing feedback text and images together.
License
Licensed under the MIT License. © Fereydoon Boroojerdi, January 2025.