path: root/megapixels/commands/faiss/build_db.py
blob: 52c4980f6d60f76aa6b44eedc7df51386c06e2ad (plain)
"""
Load all the CSV files into MySQL
"""

import os
import glob
import click
import pandas as pd

from app.models.sql_factory import engine, SqlDataset
from app.settings import app_cfg as cfg

@click.command()
@click.pass_context
def cli(ctx):
  """import the various CSVs into MySQL
  """
  load_sql_datasets(clobber=True)

def load_sql_datasets(clobber=False):
  """Load every dataset directory under DIR_FAISS_METADATA into MySQL."""
  datasets = {}
  for path in glob.iglob(os.path.join(cfg.DIR_FAISS_METADATA, "*")):
    dataset = load_sql_dataset(path, clobber)
    datasets[dataset.name] = dataset
  return datasets

def load_sql_dataset(path, clobber=False):
  """Load a single dataset directory of CSVs into its MySQL tables."""
  name = os.path.basename(path)
  dataset = SqlDataset(name)

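  # Each CSV in the dataset directory maps to one table model on the dataset;
  # files without a matching table are skipped.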
  for fn in glob.iglob(os.path.join(path, "*.csv")):
    key = os.path.basename(fn).replace(".csv", "")
    table = dataset.get_table(key)
    if table is None:
      continue
    if clobber:
      df = pd.read_csv(fn)

      # Rename the DataFrame columns to the model's column names; the raw CSV
      # headers include names like "index", which is a SQL reserved word
      df.columns = table.__table__.columns.keys()

      df.to_sql(name=table.__tablename__, con=engine, if_exists='replace', index=False)

  return dataset
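
# A minimal usage sketch (assumptions: the package path mirrors the file path
# megapixels/commands/faiss/build_db.py, and this module is wired into a click
# command group elsewhere; neither is shown in this file). The loader can also
# be called directly from a Python shell:
#
#   from megapixels.commands.faiss.build_db import load_sql_datasets
#   datasets = load_sql_datasets(clobber=True)  # drops and reloads every table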