summaryrefslogtreecommitdiff
path: root/megapixels/app/models/sql_factory.py
blob: 4adc6f48f731307f7a2757e3398deb4b5b24074f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
from urllib.parse import quote_plus

from sqlalchemy import create_engine, Table, Column, String, Integer, DateTime, Float
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.ext.declarative import AbstractConcreteBase, ConcreteBase

# MySQL connection URL assembled from environment variables.
# The user name and password are percent-encoded so that characters such as
# '@', ':' or '/' inside them cannot be mistaken for URL delimiters. An unset
# user/password becomes an empty string rather than the literal text "None".
# DB_HOST is passed through unescaped so it may carry a ":port" suffix.
connection_url = "mysql+mysqldb://{}:{}@{}/{}".format(
  quote_plus(os.getenv("DB_USER", "")),
  quote_plus(os.getenv("DB_PASS", "")),
  os.getenv("DB_HOST"),
  os.getenv("DB_NAME")
)

engine = create_engine(connection_url)
Session = sessionmaker(bind=engine)
# Module-wide session object, created when this module is imported.
session = Session()
# Declarative base shared by every table class built by SqlDataset.
Base = declarative_base(bind=engine)

class SqlDataset:
  """Lazily build and cache the ORM table classes for one dataset.

  Every table is named ``<name>_<suffix>`` and is declared on the
  module-level declarative ``Base``. ``get_table`` builds a table class on
  first request and caches it in ``self.tables`` keyed by table type.

  NOTE(review): the cache is per instance. Building the same table twice
  (e.g. from two SqlDataset objects with the same name) declares the same
  ``__tablename__`` on ``Base`` again, which SQLAlchemy is expected to reject
  -- confirm callers reuse one instance per dataset.
  """

  # Table type accepted by get_table() -> name of the builder method.
  # Resolved with getattr so a subclass can override an individual builder.
  _TABLE_BUILDERS = {
    'uuid': 'uuid_table',
    'roi': 'roi_table',
    'identity_meta': 'identity_table',
    'pose': 'pose_table',
  }

  def __init__(self, name):
    """Create a dataset wrapper.

    Args:
      name: prefix used for every table name this dataset declares.
    """
    self.name = name
    self.tables = {}

  def get_table(self, type):
    """Return the ORM class for table ``type``, building it on first use.

    Args:
      type: one of 'uuid', 'roi', 'identity_meta' or 'pose'. (It shadows the
        builtin ``type``; the name is kept so keyword callers keep working.)

    Returns:
      The cached or newly built declarative class, or None when ``type`` is
      neither cached nor a known table type. Unknown types are not cached.
    """
    if type in self.tables:
      return self.tables[type]
    builder_name = self._TABLE_BUILDERS.get(type)
    if builder_name is None:
      return None
    table = getattr(self, builder_name)()
    self.tables[type] = table
    return table

  # ==> uuids.csv <==
  # index,uuid
  # 0,f03fd921-2d56-4e83-8115-f658d6a72287
  def uuid_table(self):
    """Declare and return the ``<name>_uuid`` table class."""
    class UUID(Base):
      __tablename__ = self.name + "_uuid"
      id = Column(Integer, primary_key=True)
      uuid = Column(String(36), nullable=False)
    return UUID

  # ==> roi.csv <==
  # index,h,image_height,image_index,image_width,w,x,y
  # 0,0.33000000000000007,250,0,250,0.32999999999999996,0.33666666666666667,0.35
  def roi_table(self):
    """Declare and return the ``<name>_roi`` table class."""
    class ROI(Base):
      __tablename__ = self.name + "_roi"
      id = Column(Integer, primary_key=True)
      h = Column(Float, nullable=False)
      image_height = Column(Integer, nullable=False)
      image_index = Column(Integer, nullable=False)
      image_width = Column(Integer, nullable=False)
      w = Column(Float, nullable=False)
      x = Column(Float, nullable=False)
      y = Column(Float, nullable=False)
    return ROI

  # ==> identity.csv <==
  # index,fullname,description,gender,images,image_index
  # 0,A. J. Cook,Canadian actress,f,1,0
  def identity_table(self):
    """Declare and return the ``<name>_identity`` table class."""
    class Identity(Base):
      __tablename__ = self.name + "_identity"
      id = Column(Integer, primary_key=True)
      # NOTE(review): String(36) matches the uuid width; confirm no full name
      # or description in the source CSV is longer than 36 characters.
      fullname = Column(String(36), nullable=False)
      description = Column(String(36), nullable=False)
      gender = Column(String(1), nullable=False)
      images = Column(Integer, nullable=False)
      # NOTE(review): the CSV column is "image_index" but this column is
      # "image_id" (ROI uses "image_index"); confirm the loader maps them.
      image_id = Column(Integer, nullable=False)
    return Identity

  # ==> pose.csv <==
  # index,image_index,pitch,roll,yaw
  # 0,0,11.16264458441435,10.415885631337728,22.99719032415318
  def pose_table(self):
    """Declare and return the ``<name>_pose`` table class."""
    class Pose(Base):
      __tablename__ = self.name + "_pose"
      id = Column(Integer, primary_key=True)
      # A required data column, not part of the key: like every other table,
      # Pose is keyed on ``id`` alone.
      # NOTE(review): the CSV column is "image_index"; confirm the loader
      # maps it to "image_id".
      image_id = Column(Integer, nullable=False)
      pitch = Column(Float, nullable=False)
      roll = Column(Float, nullable=False)
      yaw = Column(Float, nullable=False)
    return Pose