diff --git a/cdqa/pipeline/cdqa_sklearn.py b/cdqa/pipeline/cdqa_sklearn.py index a15f01b..ace6770 100644 --- a/cdqa/pipeline/cdqa_sklearn.py +++ b/cdqa/pipeline/cdqa_sklearn.py @@ -35,13 +35,13 @@ class QAPipeline(BaseEstimator): -------- >>> from cdqa.pipeline import QAPipeline >>> qa_pipeline = QAPipeline(reader='bert_qa_squad_vCPU-sklearn.joblib') - >>> qa_pipeline.fit_retriever(X=df) - >>> prediction = qa_pipeline.predict(X='When BNP Paribas was created?') + >>> qa_pipeline.fit_retriever(df=df) + >>> prediction = qa_pipeline.predict(query='When BNP Paribas was created?') >>> from cdqa.pipeline import QAPipeline >>> qa_pipeline = QAPipeline() >>> qa_pipeline.fit_reader('train-v1.1.json') - >>> qa_pipeline.fit_retriever(X=df) + >>> qa_pipeline.fit_retriever(df=df) >>> prediction = qa_pipeline.predict(X='When BNP Paribas was created?') """ @@ -140,7 +140,7 @@ def predict( Parameters ---------- - X: str + query: str Sample (question) to perform a prediction on n_predictions: int or None (default: None).