summaryrefslogtreecommitdiff
path: root/ricky
diff options
context:
space:
mode:
Diffstat (limited to 'ricky')
-rw-r--r--ricky/config.py1
-rw-r--r--ricky/dataset.py32
-rw-r--r--ricky/params/__init__.py34
-rw-r--r--ricky/utils.py2
4 files changed, 16 insertions, 53 deletions
diff --git a/ricky/config.py b/ricky/config.py
index 1cf3255..7f8715f 100644
--- a/ricky/config.py
+++ b/ricky/config.py
@@ -10,7 +10,6 @@ TEST_URL = (
)
PB_BASE = "http://asdf.us/"
-
def _add_pb_base(path):
return os.path.join(PB_BASE, path)
diff --git a/ricky/dataset.py b/ricky/dataset.py
index 4f8a422..478ee5e 100644
--- a/ricky/dataset.py
+++ b/ricky/dataset.py
@@ -3,24 +3,6 @@ from ricky.utils import data_from_image
from pybrain.datasets import SupervisedDataSet
-# while subclassing this works, we should try to detect the length of params
-# and build a new data set for each type of params set...
-# therefore, an instance of SupervisedDataSet could actually be
-# accessed through the params instance...simplified one-to-one mapping
-
-# we are limited to only one classifier per params instance as well
-# however this is sort of a good thing, because built into the params
-# class can be a method that randomizes params, and then evaluates
-
-# we might be able to get this done through multiple inheritance
-# keep all dataset related stuff in a separate class to make it better organized
-
-# we need
-# .evaluate
-# .generate_liked_image
-# .train_from_url_list
-# .reset
-
class DataSet(SupervisedDataSet):
@@ -35,11 +17,9 @@ class DataSet(SupervisedDataSet):
target = 1
data_list = [data_from_image(image) for image in url_list if image]
for data in data_list:
- for params_class in ricky.params.Params.__subclasses__():
- if data['module'] == params_class.__name__:
- params_instance = params_class()
- params_instance.from_dict(data['params'])
- self.addSample(
- params_instance.as_normalized(),
- target
- )
+ params_instance = Params.new_class_from_classname(data['module'])
+ params_instance.from_dict(data['params'])
+ self.addSample(
+ params_instance.as_normalized(),
+ target
+ )
diff --git a/ricky/params/__init__.py b/ricky/params/__init__.py
index 53ac530..82cbc79 100644
--- a/ricky/params/__init__.py
+++ b/ricky/params/__init__.py
@@ -25,27 +25,14 @@ class Params(object):
def __len__(self):
return len(self._params)
+
def _load_probabilities_json(self, probabilities_file=None):
- if probabilities_file:
- filepath = probabilities_file
- else:
- filepath = os.path.join(
+ filepath = probabilities_file or \
+ os.path.join(
PROBABILITIES_DIR,
"%s.json" % (self.__class__.__name__)
)
- try:
- f = open(filepath, 'r')
- data = f.read()
- f.close()
- return json.loads(data)
- except json.scanner.JSONDecodeError as e:
- sys.stderr.write("Invalid Json - Problem decoding %s\n" % filepath)
- sys.stderr.write("%s\n" % e)
- sys.exit(1)
- except IOError:
- sys.stderr.write(
- "Could not find probabilities file %s\n" % filepath)
- sys.exit(1)
+ return json.load(open(filepath))
def randomize(
self,
@@ -65,7 +52,6 @@ class Params(object):
param.randomize(probabilities=probabilities_dict.get(param.name))
-
def execute(self):
"""calls the associated api"""
if OFFLINE:
@@ -94,12 +80,6 @@ class Params(object):
result[param.name] = param.value
return result
- def as_normalized(self):
- return tuple([
- {'name': param.name, 'normalized': param.as_normalized()}
- for param in self._params
- ])
-
def as_serialized(self):
"""
returns params in serialized form to use in a dataset
@@ -120,7 +100,11 @@ class Params(object):
param.value = params_dict[param.name]
@classmethod
- def from_classname(cls, classname):
+ def new_class_from_classname(cls, classname):
+ """
+ #FIXME make this class a plugin parent class
+ anything else look weird here?
+ """
for subclass in cls.__subclasses__():
if subclass.__name__ == classname:
return subclass()
diff --git a/ricky/utils.py b/ricky/utils.py
index 04e2074..98dca5f 100644
--- a/ricky/utils.py
+++ b/ricky/utils.py
@@ -17,7 +17,7 @@ def data_from_url(url):
result = ImCmd.search(newfile=newfile).first()
try:
return {
- "module": result.tag.split(":")[0],
+ "module": result.tag.split(":")[0],
"params": json.loads(result.dataobj)
}
except AttributeError: