I have 2 datastore models:
class KindA(ndb.Model):
    field_a1 = ndb.StringProperty()
    field_a2 = ndb.StringProperty()

class KindB(ndb.Model):
    field_b1 = ndb.StringProperty()
    field_b2 = ndb.StringProperty()
    key_to_kind_a = ndb.KeyProperty(KindA)
I want to query KindB and output it to a csv file, but when an entity of KindB points to an entity in KindA, I want those KindA fields to be in the csv as well.
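For context, the join I want is trivial in plain ndb outside of a pipeline (a minimal sketch, not workable at Dataflow scale):

import csv

with open('export.csv', 'w') as f:
    writer = csv.writer(f)
    writer.writerow(['field_b1', 'field_b2', 'field_a1', 'field_a2'])
    for b in KindB.query():
        # Follow the KeyProperty to the referenced KindA entity, if any.
        a = b.key_to_kind_a.get() if b.key_to_kind_a else None
        writer.writerow([b.field_b1, b.field_b2,
                         a.field_a1 if a else '',
                         a.field_a2 if a else ''])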
If I were able to use ndb inside of a transform, I would set up my pipeline like this:
def format(element):  # element is an `entity_pb2` object of KindB
    try:
        obj_a_key_id = element.properties.get('key_to_kind_a', None).key_value.path[0]
    except:
        obj_a_key_id = None
    # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<< HOW DO I DO THIS
    obj_a = ndb.Key(KindA, obj_a_key_id).get() if obj_a_key_id else None
    return ",".join([
        element.properties.get('field_b1', None).string_value,
        element.properties.get('field_b2', None).string_value,
        obj_a.properties.get('field_a1', None).string_value if obj_a else '',
        obj_a.properties.get('field_a2', None).string_value if obj_a else '',
    ])
def build_pipeline(project, start_date, end_date, export_path):
    query = query_pb2.Query()
    query.kind.add().name = 'KindB'
    filter_1 = datastore_helper.set_property_filter(query_pb2.Filter(), 'field_b1', PropertyFilter.GREATER_THAN, start_date)
    filter_2 = datastore_helper.set_property_filter(query_pb2.Filter(), 'field_b1', PropertyFilter.LESS_THAN, end_date)
    datastore_helper.set_composite_filter(query.filter, CompositeFilter.AND, filter_1, filter_2)

    p = beam.Pipeline(options=pipeline_options)
    _ = (p
         | 'read from datastore' >> ReadFromDatastore(project, query, None)
         | 'format' >> beam.Map(format)
         | 'write' >> apache_beam.io.WriteToText(
             file_path_prefix=export_path,
             file_name_suffix='.csv',
             header='field_b1,field_b2,field_a1,field_a2',
             num_shards=1)
         )
    return p
I suppose I could use ReadFromDatastore to query all entities of KindA and then use CoGroupByKey to merge them, but KindA has millions of records, so that would be very inefficient.
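For reference, that rejected approach would look roughly like this (a minimal sketch; query_a, query_b and the key-extraction helpers are hypothetical names, just to show the shape):

def a_key_id(entity):
    # Hypothetical helper: the numeric id at the end of a KindA entity's key path.
    return entity.key.path[-1].id

def b_foreign_key_id(entity):
    # Hypothetical helper: the id stored in KindB's key_to_kind_a property.
    return entity.properties['key_to_kind_a'].key_value.path[-1].id

kind_a = (p
          | 'read KindA' >> ReadFromDatastore(project, query_a, None)
          | 'key KindA' >> beam.Map(lambda e: (a_key_id(e), e)))
kind_b = (p
          | 'read KindB' >> ReadFromDatastore(project, query_b, None)
          | 'key KindB' >> beam.Map(lambda e: (b_foreign_key_id(e), e)))
joined = ({'a': kind_a, 'b': kind_b} | 'join KindA/KindB' >> beam.CoGroupByKey())

This reads every single KindA entity just to join a small subset of them, which is exactly the inefficiency I want to avoid.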
Following the suggestion in this answer: https://stackoverflow.com/a/49130224/4458510, I created the utils below. They are inspired by the source code of DatastoreWriteFn in apache_beam.io.gcp.datastore.v1.datastoreio and of write_mutations and fetch_entities in apache_beam.io.gcp.datastore.v1.helper:

import logging
import time
from socket import error as _socket_error
from apache_beam.metrics import Metrics
from apache_beam.transforms import DoFn, window
from apache_beam.utils import retry
from apache_beam.io.gcp.datastore.v1.adaptive_throttler import AdaptiveThrottler
from apache_beam.io.gcp.datastore.v1.helper import make_partition, retry_on_rpc_error, get_datastore
from apache_beam.io.gcp.datastore.v1.util import MovingSum
from apache_beam.utils.windowed_value import WindowedValue
from google.cloud.proto.datastore.v1 import datastore_pb2, query_pb2
from googledatastore.connection import Datastore, RPCError
_WRITE_BATCH_INITIAL_SIZE = 200
_WRITE_BATCH_MAX_SIZE = 500
_WRITE_BATCH_MIN_SIZE = 10
_WRITE_BATCH_TARGET_LATENCY_MS = 5000
def _fetch_keys(project_id, keys, datastore, throttler, rpc_stats_callback=None, throttle_delay=1):
    """Looks up a batch of keys in Datastore, with client-side throttling and retries."""
    req = datastore_pb2.LookupRequest()
    req.project_id = project_id
    for key in keys:
        req.keys.add().CopyFrom(key)

    @retry.with_exponential_backoff(num_retries=5, retry_filter=retry_on_rpc_error)
    def run(request):
        # Client-side throttling.
        while throttler.throttle_request(time.time() * 1000):
            logging.info("Delaying request for %ds due to previous failures", throttle_delay)
            time.sleep(throttle_delay)
            if rpc_stats_callback:
                rpc_stats_callback(throttled_secs=throttle_delay)
        try:
            start_time = time.time()
            response = datastore.lookup(request)
            end_time = time.time()
            if rpc_stats_callback:
                rpc_stats_callback(successes=1)
            throttler.successful_request(start_time * 1000)
            commit_time_ms = int((end_time - start_time) * 1000)
            return response, commit_time_ms
        except (RPCError, _socket_error):
            if rpc_stats_callback:
                rpc_stats_callback(errors=1)
            raise

    return run(req)
# Copied from _DynamicBatchSizer in apache_beam.io.gcp.datastore.v1.datastoreio
class _DynamicBatchSizer(object):
    """Determines request sizes for future Datastore RPCS."""

    def __init__(self):
        self._commit_time_per_entity_ms = MovingSum(window_ms=120000, bucket_ms=10000)

    def get_batch_size(self, now):
        """Returns the recommended size for datastore RPCs at this time."""
        if not self._commit_time_per_entity_ms.has_data(now):
            return _WRITE_BATCH_INITIAL_SIZE
        recent_mean_latency_ms = (self._commit_time_per_entity_ms.sum(now) /
                                  self._commit_time_per_entity_ms.count(now))
        return max(_WRITE_BATCH_MIN_SIZE,
                   min(_WRITE_BATCH_MAX_SIZE,
                       _WRITE_BATCH_TARGET_LATENCY_MS / max(recent_mean_latency_ms, 1)))

    def report_latency(self, now, latency_ms, num_mutations):
        """Reports the latency of an RPC to Datastore.

        Args:
          now: double, completion time of the RPC as seconds since the epoch.
          latency_ms: double, the observed latency in milliseconds for this RPC.
          num_mutations: int, number of mutations contained in the RPC.
        """
        self._commit_time_per_entity_ms.add(now, latency_ms / num_mutations)
class LookupKeysFn(DoFn):
    """A `DoFn` that looks up keys in the Datastore."""

    def __init__(self, project_id, fixed_batch_size=None):
        self._project_id = project_id
        self._datastore = None
        self._fixed_batch_size = fixed_batch_size
        self._rpc_successes = Metrics.counter(self.__class__, "datastoreRpcSuccesses")
        self._rpc_errors = Metrics.counter(self.__class__, "datastoreRpcErrors")
        self._throttled_secs = Metrics.counter(self.__class__, "cumulativeThrottlingSeconds")
        self._throttler = AdaptiveThrottler(window_ms=120000, bucket_ms=1000, overload_ratio=1.25)
        self._elements = []
        self._batch_sizer = None
        self._target_batch_size = None

    def _update_rpc_stats(self, successes=0, errors=0, throttled_secs=0):
        """Callback function, called by _fetch_keys()"""
        self._rpc_successes.inc(successes)
        self._rpc_errors.inc(errors)
        self._throttled_secs.inc(throttled_secs)

    def start_bundle(self):
        """(Re)initialize: connection with datastore, _DynamicBatchSizer obj"""
        self._elements = []
        self._datastore = get_datastore(self._project_id)
        if self._fixed_batch_size:
            self._target_batch_size = self._fixed_batch_size
        else:
            self._batch_sizer = _DynamicBatchSizer()
            self._target_batch_size = self._batch_sizer.get_batch_size(time.time() * 1000)

    def process(self, element):
        """Collect elements and process them as a batch"""
        self._elements.append(element)
        if len(self._elements) >= self._target_batch_size:
            return self._flush_batch()

    def finish_bundle(self):
        """Flush any remaining elements"""
        if self._elements:
            objs = self._flush_batch()
            for obj in objs:
                yield WindowedValue(obj, window.MAX_TIMESTAMP, [window.GlobalWindow()])

    def _flush_batch(self):
        """Fetch all of the collected keys from datastore"""
        response, latency_ms = _fetch_keys(
            project_id=self._project_id,
            keys=self._elements,
            datastore=self._datastore,
            throttler=self._throttler,
            rpc_stats_callback=self._update_rpc_stats)
        logging.info("Successfully read %d keys in %dms.", len(self._elements), latency_ms)

        if not self._fixed_batch_size:
            now = time.time() * 1000
            self._batch_sizer.report_latency(now, latency_ms, len(self._elements))
            self._target_batch_size = self._batch_sizer.get_batch_size(now)

        self._elements = []
        return [entity_result.entity for entity_result in response.found]
class LookupEntityFieldFn(LookupKeysFn):
    """Looks up a field on an entity_pb2 object.

    Expects an entity_pb2 object as input.
    Outputs a tuple, where the first element is the input object and the second
    element is the object found during the lookup.
    """

    def __init__(self, project_id, field_name, fixed_batch_size=None):
        super(LookupEntityFieldFn, self).__init__(project_id=project_id, fixed_batch_size=fixed_batch_size)
        self._field_name = field_name

    @staticmethod
    def _pb2_key_value_to_tuple(kv):
        """Converts a key_value object into a tuple, so that it can be a dictionary key"""
        path = []
        for p in kv.path:
            path.append(p.name)
            path.append(p.id)
        return tuple(path)

    def _flush_batch(self):
        """Fetches the referenced entities and pairs them with the input objects."""
        _elements = self._elements
        keys_to_fetch = []
        for element in self._elements:
            kv = element.properties.get(self._field_name, None)
            if kv and kv.key_value and kv.key_value.path:
                keys_to_fetch.append(kv.key_value)
        # Temporarily swap self._elements to the extracted keys so the parent
        # class's _flush_batch fetches those keys.
        self._elements = keys_to_fetch
        read_keys = super(LookupEntityFieldFn, self)._flush_batch()

        _by_key = {self._pb2_key_value_to_tuple(entity.key): entity for entity in read_keys}
        output_pairs = []
        for input_obj in _elements:
            kv = input_obj.properties.get(self._field_name, None)
            output_obj = None
            if kv and kv.key_value and kv.key_value.path:
                output_obj = _by_key.get(self._pb2_key_value_to_tuple(kv.key_value), None)
            output_pairs.append((input_obj, output_obj))
        return output_pairs
The key is the line:

response = datastore.lookup(request)

where:

- datastore = get_datastore(project_id) (from apache_beam.io.gcp.datastore.v1.helper.get_datastore)
- request is a LookupRequest from google.cloud.proto.datastore.v1.datastore_pb2
- response is a LookupResponse from google.cloud.proto.datastore.v1.datastore_pb2
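Stripped of the throttling, retry and batch-sizing machinery, the core of the lookup on its own is roughly this (a minimal sketch using the same helpers as above):

def lookup_entities(project_id, key_pbs):
    # Minimal sketch: fetch the entity_pb2 objects for a list of Key protobufs.
    datastore = get_datastore(project_id)
    request = datastore_pb2.LookupRequest()
    request.project_id = project_id
    for key in key_pbs:
        request.keys.add().CopyFrom(key)
    response = datastore.lookup(request)  # a LookupResponse
    # Found entities come back as EntityResult messages; keys that don't exist
    # show up in response.missing instead.
    return [entity_result.entity for entity_result in response.found]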
The rest of the code above does things like maintaining the datastore connection, client-side throttling, retrying with exponential backoff, dynamic batch sizing and RPC metrics (to be honest I don't know how important these bits are, I just came across them while browsing the apache_beam source code).

The resulting util function LookupEntityFieldFn(project_id, field_name) is a DoFn that takes an entity_pb2 object as input, extracts and fetches the key property that resides on the field field_name, and outputs the result as a tuple (the fetch result paired with the input object).
My pipeline code then became:
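Roughly the following, with format reworked to unpack the (KindB, KindA) tuples that LookupEntityFieldFn emits (a sketch; the lookup step is the only real change from the pipeline in the question):

def format(element):  # element is a (KindB entity_pb2, KindA entity_pb2 or None) tuple
    obj_b, obj_a = element
    return ",".join([
        obj_b.properties.get('field_b1').string_value,
        obj_b.properties.get('field_b2').string_value,
        obj_a.properties.get('field_a1').string_value if obj_a else '',
        obj_a.properties.get('field_a2').string_value if obj_a else '',
    ])

def build_pipeline(project, start_date, end_date, export_path):
    query = query_pb2.Query()
    query.kind.add().name = 'KindB'
    filter_1 = datastore_helper.set_property_filter(query_pb2.Filter(), 'field_b1', PropertyFilter.GREATER_THAN, start_date)
    filter_2 = datastore_helper.set_property_filter(query_pb2.Filter(), 'field_b1', PropertyFilter.LESS_THAN, end_date)
    datastore_helper.set_composite_filter(query.filter, CompositeFilter.AND, filter_1, filter_2)

    p = beam.Pipeline(options=pipeline_options)
    _ = (p
         | 'read from datastore' >> ReadFromDatastore(project, query, None)
         | 'lookup KindA' >> beam.ParDo(LookupEntityFieldFn(project_id=project, field_name='key_to_kind_a'))
         | 'format' >> beam.Map(format)
         | 'write' >> apache_beam.io.WriteToText(
             file_path_prefix=export_path,
             file_name_suffix='.csv',
             header='field_b1,field_b2,field_a1,field_a2',
             num_shards=1)
         )
    return p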